`ultralytics 8.2.73` Meta SAM2 Refactor (#14867)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/14980/head v8.2.73
Laughing 6 months ago committed by GitHub
parent bea4c93278
commit 5d9046abda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 24
      docs/en/reference/models/sam/build.md
  2. 4
      docs/en/reference/models/sam/model.md
  3. 72
      docs/en/reference/models/sam/modules/blocks.md
  4. 4
      docs/en/reference/models/sam/modules/decoders.md
  5. 24
      docs/en/reference/models/sam/modules/encoders.md
  6. 20
      docs/en/reference/models/sam/modules/memory_attention.md
  7. 8
      docs/en/reference/models/sam/modules/sam.md
  8. 4
      docs/en/reference/models/sam/modules/tiny_encoder.md
  9. 52
      docs/en/reference/models/sam/modules/utils.md
  10. 8
      docs/en/reference/models/sam/predict.md
  11. 36
      docs/en/reference/models/sam2/build.md
  12. 16
      docs/en/reference/models/sam2/model.md
  13. 16
      docs/en/reference/models/sam2/modules/decoders.md
  14. 28
      docs/en/reference/models/sam2/modules/encoders.md
  15. 20
      docs/en/reference/models/sam2/modules/memory_attention.md
  16. 16
      docs/en/reference/models/sam2/modules/sam2.md
  17. 56
      docs/en/reference/models/sam2/modules/sam2_blocks.md
  18. 44
      docs/en/reference/models/sam2/modules/utils.md
  19. 16
      docs/en/reference/models/sam2/predict.md
  20. 14
      mkdocs.yml
  21. 5
      ultralytics/__init__.py
  22. 3
      ultralytics/models/__init__.py
  23. 4
      ultralytics/models/sam/__init__.py
  24. 48
      ultralytics/models/sam/amg.py
  25. 209
      ultralytics/models/sam/build.py
  26. 120
      ultralytics/models/sam/model.py
  27. 1131
      ultralytics/models/sam/modules/blocks.py
  28. 413
      ultralytics/models/sam/modules/decoders.py
  29. 831
      ultralytics/models/sam/modules/encoders.py
  30. 79
      ultralytics/models/sam/modules/memory_attention.py
  31. 903
      ultralytics/models/sam/modules/sam.py
  32. 502
      ultralytics/models/sam/modules/tiny_encoder.py
  33. 209
      ultralytics/models/sam/modules/transformer.py
  34. 108
      ultralytics/models/sam/modules/utils.py
  35. 474
      ultralytics/models/sam/predict.py
  36. 6
      ultralytics/models/sam2/__init__.py
  37. 156
      ultralytics/models/sam2/build.py
  38. 97
      ultralytics/models/sam2/model.py
  39. 1
      ultralytics/models/sam2/modules/__init__.py
  40. 305
      ultralytics/models/sam2/modules/decoders.py
  41. 332
      ultralytics/models/sam2/modules/encoders.py
  42. 804
      ultralytics/models/sam2/modules/sam2.py
  43. 715
      ultralytics/models/sam2/modules/sam2_blocks.py
  44. 177
      ultralytics/models/sam2/predict.py

@ -1,6 +1,6 @@
---
description: Discover detailed instructions for building various Segment Anything Model (SAM) architectures with Ultralytics, including SAM ViT and Mobile-SAM.
keywords: Ultralytics, SAM model, Segment Anything Model, SAM ViT, Mobile-SAM, model building, deep learning, AI
description: Discover detailed instructions for building various Segment Anything Model (SAM) and Segment Anything Model 2 (SAM 2) architectures with Ultralytics, including SAM ViT and Mobile-SAM.
keywords: Ultralytics, SAM model, Segment Anything Model, SAM 2 model, Segment Anything Model 2, SAM ViT, Mobile-SAM, model building, deep learning, AI
---
# Reference for `ultralytics/models/sam/build.py`
@ -27,10 +27,30 @@ keywords: Ultralytics, SAM model, Segment Anything Model, SAM ViT, Mobile-SAM, m
<br><br><hr><br>
## ::: ultralytics.models.sam.build.build_sam2_t
<br><br><hr><br>
## ::: ultralytics.models.sam.build.build_sam2_s
<br><br><hr><br>
## ::: ultralytics.models.sam.build.build_sam2_b
<br><br><hr><br>
## ::: ultralytics.models.sam.build.build_sam2_l
<br><br><hr><br>
## ::: ultralytics.models.sam.build._build_sam
<br><br><hr><br>
## ::: ultralytics.models.sam.build._build_sam2
<br><br><hr><br>
## ::: ultralytics.models.sam.build.build_sam
<br><br>

@ -1,6 +1,6 @@
---
description: Explore the SAM (Segment Anything Model) interface for real-time image segmentation. Learn about promptable segmentation and zero-shot capabilities.
keywords: Ultralytics, SAM, Segment Anything Model, image segmentation, real-time segmentation, zero-shot performance, promptable segmentation, SA-1B dataset
description: Explore the SAM (Segment Anything Model) and SAM 2 (Segment Anything Model 2) interface for real-time image segmentation. Learn about promptable segmentation and zero-shot capabilities.
keywords: Ultralytics, SAM, Segment Anything Model, SAM 2, Segment Anything Model 2, image segmentation, real-time segmentation, zero-shot performance, promptable segmentation, SA-1B dataset
---
# Reference for `ultralytics/models/sam/model.py`

@ -0,0 +1,72 @@
---
description: Explore detailed documentation of various SAM and SAM 2 modules such as MaskDownSampler, CXBlock, and more, available in Ultralytics' repository.
keywords: Ultralytics, SAM encoder, SAM 2 encoder, DropPath, MaskDownSampler, CXBlock, Fuser, TwoWayTransformer, TwoWayAttentionBlock, RoPEAttention, MultiScaleAttention, MultiScaleBlock. PositionEmbeddingSine, do_pool
---
# Reference for `ultralytics/models/sam/modules/blocks.py`
!!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/blocks.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/blocks.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam/modules/blocks.py) 🛠. Thank you 🙏!
<br>
## ::: ultralytics.models.sam.modules.blocks.DropPath
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.blocks.MaskDownSampler
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.blocks.CXBlock
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.blocks.Fuser
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.blocks.SAM2TwoWayAttentionBlock
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.blocks.SAM2TwoWayTransformer
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.blocks.RoPEAttention
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.blocks.MultiScaleAttention
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.blocks.MultiScaleBlock
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.blocks.PositionEmbeddingSine
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.blocks.PositionEmbeddingRandom
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.blocks.Block
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.blocks.REAttention
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.blocks.PatchEmbed
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.blocks.do_pool
<br><br>

@ -13,4 +13,8 @@ keywords: Ultralytics, MaskDecoder, MLP, machine learning, transformer architect
## ::: ultralytics.models.sam.modules.decoders.MaskDecoder
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.decoders.SAM2MaskDecoder
<br><br>

@ -19,34 +19,18 @@ keywords: Ultralytics, SAM encoder, ImageEncoderViT, PromptEncoder, PositionEmbe
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.encoders.PositionEmbeddingRandom
## ::: ultralytics.models.sam.modules.encoders.MemoryEncoder
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.encoders.Block
## ::: ultralytics.models.sam.modules.encoders.ImageEncoder
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.encoders.Attention
## ::: ultralytics.models.sam.modules.encoders.FpnNeck
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.encoders.PatchEmbed
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.encoders.window_partition
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.encoders.window_unpartition
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.encoders.get_rel_pos
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.encoders.add_decomposed_rel_pos
## ::: ultralytics.models.sam.modules.encoders.Hiera
<br><br>

@ -0,0 +1,20 @@
---
description: Explore detailed documentation of various SAM 2 encoder modules such as MemoryAttentionLayer, MemoryAttention, available in Ultralytics' repository.
keywords: Ultralytics, SAM 2 encoder, MemoryAttentionLayer, MemoryAttention
---
# Reference for `ultralytics/models/sam/modules/memory_attention.py`
!!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/memory_attention.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/memory_attention.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam/modules/memory_attention.py) 🛠. Thank you 🙏!
<br>
## ::: ultralytics.models.sam.modules.memory_attention.MemoryAttentionLayer
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.memory_attention.MemoryAttention
<br><br>

@ -1,6 +1,6 @@
---
description: Discover the Ultralytics SAM module for object segmentation. Learn about its components, such as image encoders and mask decoders, in this comprehensive guide.
keywords: Ultralytics, SAM Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning
description: Discover the Ultralytics SAM and SAM 2 module for object segmentation. Learn about its components, such as image encoders and mask decoders, in this comprehensive guide.
keywords: Ultralytics, SAM Module, SAM 2 Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning
---
# Reference for `ultralytics/models/sam/modules/sam.py`
@ -13,4 +13,8 @@ keywords: Ultralytics, SAM Module, object segmentation, image encoder, mask deco
## ::: ultralytics.models.sam.modules.sam.SAMModel
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.sam.SAM2Model
<br><br>

@ -47,10 +47,6 @@ keywords: Ultralytics, TinyViT, Conv2d_BN, PatchEmbed, MBConv, Attention, PyTorc
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.tiny_encoder.LayerNorm2d
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.tiny_encoder.TinyViT
<br><br>

@ -0,0 +1,52 @@
---
description: Explore the detailed API reference for Ultralytics SAM and SAM 2 models.
keywords: Ultralytics, SAM, SAM 2, API Reference, models, window partition, data processing, YOLO
---
# Reference for `ultralytics/models/sam/modules/utils.py`
!!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/utils.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/utils.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam/modules/utils.py) 🛠. Thank you 🙏!
<br>
## ::: ultralytics.models.sam.modules.utils.select_closest_cond_frames
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.utils.get_1d_sine_pe
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.utils.init_t_xy
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.utils.compute_axial_cis
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.utils.reshape_for_broadcast
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.utils.apply_rotary_enc
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.utils.window_partition
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.utils.window_unpartition
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.utils.get_rel_pos
<br><br><hr><br>
## ::: ultralytics.models.sam.modules.utils.add_decomposed_rel_pos
<br><br>

@ -1,6 +1,6 @@
---
description: Explore Ultralytics SAM Predictor for advanced, real-time image segmentation using the Segment Anything Model (SAM). Complete implementation details and auxiliary utilities.
keywords: Ultralytics, SAM, Segment Anything Model, image segmentation, real-time, prediction, AI, machine learning, Python, torch, inference
description: Explore Ultralytics SAM and SAM 2 Predictor for advanced, real-time image segmentation using the Segment Anything Model (SAM and SAM 2). Complete implementation details and auxiliary utilities.
keywords: Ultralytics, SAM, Segment Anything Model, SAM 2, Segment Anything Model 2, image segmentation, real-time, prediction, AI, machine learning, Python, torch, inference
---
# Reference for `ultralytics/models/sam/predict.py`
@ -13,4 +13,8 @@ keywords: Ultralytics, SAM, Segment Anything Model, image segmentation, real-tim
## ::: ultralytics.models.sam.predict.Predictor
<br><br><hr><br>
## ::: ultralytics.models.sam.predict.SAM2Predictor
<br><br>

@ -1,36 +0,0 @@
---
description: Discover detailed instructions for building various Segment Anything Model 2 (SAM 2) architectures with Ultralytics.
keywords: Ultralytics, SAM 2 model, Segment Anything Model 2, SAM, model building, deep learning, AI
---
# Reference for `ultralytics/models/sam2/build.py`
!!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/build.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/build.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/build.py) 🛠. Thank you 🙏!
<br>
## ::: ultralytics.models.sam2.build.build_sam2_t
<br><br><hr><br>
## ::: ultralytics.models.sam2.build.build_sam2_s
<br><br><hr><br>
## ::: ultralytics.models.sam2.build.build_sam2_b
<br><br><hr><br>
## ::: ultralytics.models.sam2.build.build_sam2_l
<br><br><hr><br>
## ::: ultralytics.models.sam2.build._build_sam2
<br><br><hr><br>
## ::: ultralytics.models.sam2.build.build_sam2
<br><br>

@ -1,16 +0,0 @@
---
description: Explore the SAM 2 (Segment Anything Model 2) interface for real-time image segmentation. Learn about promptable segmentation and zero-shot capabilities.
keywords: Ultralytics, SAM 2, Segment Anything Model 2, image segmentation, real-time segmentation, zero-shot performance, promptable segmentation, SA-1B dataset
---
# Reference for `ultralytics/models/sam2/model.py`
!!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/model.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/model.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/model.py) 🛠. Thank you 🙏!
<br>
## ::: ultralytics.models.sam2.model.SAM2
<br><br>

@ -1,16 +0,0 @@
---
description: Explore the MaskDecoder and MLP modules in Ultralytics for efficient mask prediction using transformer architecture. Detailed attributes, functionalities, and implementation.
keywords: Ultralytics, MaskDecoder, MLP, machine learning, transformer architecture, mask prediction, neural networks, PyTorch modules
---
# Reference for `ultralytics/models/sam2/modules/decoders.py`
!!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/decoders.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/decoders.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/decoders.py) 🛠. Thank you 🙏!
<br>
## ::: ultralytics.models.sam2.modules.decoders.MaskDecoder
<br><br>

@ -1,28 +0,0 @@
---
description: Discover the Ultralytics SAM 2 module for object segmentation. Learn about its components, such as image encoders and mask decoders, in this comprehensive guide.
keywords: Ultralytics, SAM 2 Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning
---
# Reference for `ultralytics/models/sam2/modules/encoders.py`
!!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/encoders.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/encoders.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/encoders.py) 🛠. Thank you 🙏!
<br>
## ::: ultralytics.models.sam2.modules.encoders.MemoryEncoder
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.encoders.ImageEncoder
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.encoders.FpnNeck
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.encoders.Hiera
<br><br>

@ -1,20 +0,0 @@
---
description: Explore detailed documentation of various SAM 2 encoder modules such as MemoryAttentionLayer, MemoryAttention, available in Ultralytics' repository.
keywords: Ultralytics, SAM 2 encoder, MemoryAttentionLayer, MemoryAttention
---
# Reference for `ultralytics/models/sam2/modules/memory_attention.py`
!!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/memory_attention.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/memory_attention.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/memory_attention.py) 🛠. Thank you 🙏!
<br>
## ::: ultralytics.models.sam2.modules.memory_attention.MemoryAttentionLayer
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.memory_attention.MemoryAttention
<br><br>

@ -1,16 +0,0 @@
---
description: Discover the Ultralytics SAM 2 module for object segmentation. Learn about its components, such as image encoders and mask decoders, in this comprehensive guide.
keywords: Ultralytics, SAM 2 Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning
---
# Reference for `ultralytics/models/sam2/modules/sam2.py`
!!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/sam2.py) 🛠. Thank you 🙏!
<br>
## ::: ultralytics.models.sam2.modules.sam2.SAM2Model
<br><br>

@ -1,56 +0,0 @@
---
description: Explore detailed documentation of various SAM 2 modules such as MaskDownSampler, CXBlock, and more, available in Ultralytics' repository.
keywords: Ultralytics, SAM 2 encoder, DropPath, MaskDownSampler, CXBlock, Fuser, TwoWayTransformer, TwoWayAttentionBlock, RoPEAttention, MultiScaleAttention, MultiScaleBlock. PositionEmbeddingSine, do_pool
---
# Reference for `ultralytics/models/sam2/modules/sam2_blocks.py`
!!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2_blocks.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2_blocks.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/sam2_blocks.py) 🛠. Thank you 🙏!
<br>
## ::: ultralytics.models.sam2.modules.sam2_blocks.DropPath
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.sam2_blocks.MaskDownSampler
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.sam2_blocks.CXBlock
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.sam2_blocks.Fuser
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.sam2_blocks.TwoWayAttentionBlock
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.sam2_blocks.TwoWayTransformer
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.sam2_blocks.RoPEAttention
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.sam2_blocks.MultiScaleAttention
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.sam2_blocks.MultiScaleBlock
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.sam2_blocks.PositionEmbeddingSine
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.sam2_blocks.do_pool
<br><br>

@ -1,44 +0,0 @@
---
description: Explore the detailed API reference for Ultralytics SAM 2 models.
keywords: Ultralytics, SAM 2, API Reference, models, window partition, data processing, YOLO
---
# Reference for `ultralytics/models/sam2/modules/utils.py`
!!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/utils.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/utils.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/utils.py) 🛠. Thank you 🙏!
<br>
## ::: ultralytics.models.sam2.modules.utils.select_closest_cond_frames
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.utils.get_1d_sine_pe
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.utils.init_t_xy
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.utils.compute_axial_cis
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.utils.reshape_for_broadcast
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.utils.apply_rotary_enc
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.utils.window_partition
<br><br><hr><br>
## ::: ultralytics.models.sam2.modules.utils.window_unpartition
<br><br>

@ -1,16 +0,0 @@
---
description: Explore Ultralytics SAM 2 Predictor for advanced, real-time image segmentation using the Segment Anything Model 2 (SAM 2). Complete implementation details and auxiliary utilities.
keywords: Ultralytics, SAM 2, Segment Anything Model 2, image segmentation, real-time, prediction, AI, machine learning, Python, torch, inference
---
# Reference for `ultralytics/models/sam2/predict.py`
!!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/predict.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/predict.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/predict.py) 🛠. Thank you 🙏!
<br>
## ::: ultralytics.models.sam2.predict.SAM2Predictor
<br><br>

@ -503,23 +503,15 @@ nav:
- build: reference/models/sam/build.md
- model: reference/models/sam/model.md
- modules:
- blocks: reference/models/sam/modules/blocks.md
- decoders: reference/models/sam/modules/decoders.md
- encoders: reference/models/sam/modules/encoders.md
- memory_attention: reference/models/sam/modules/memory_attention.md
- sam: reference/models/sam/modules/sam.md
- tiny_encoder: reference/models/sam/modules/tiny_encoder.md
- transformer: reference/models/sam/modules/transformer.md
- utils: reference/models/sam/modules/utils.md
- predict: reference/models/sam/predict.md
- sam2:
- build: reference/models/sam2/build.md
- model: reference/models/sam2/model.md
- modules:
- decoders: reference/models/sam2/modules/decoders.md
- encoders: reference/models/sam2/modules/encoders.md
- memory_attention: reference/models/sam2/modules/memory_attention.md
- sam2: reference/models/sam2/modules/sam2.md
- sam2_blocks: reference/models/sam2/modules/sam2_blocks.md
- utils: reference/models/sam2/modules/utils.md
- predict: reference/models/sam2/predict.md
- utils:
- loss: reference/models/utils/loss.md
- ops: reference/models/utils/ops.md

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.72"
__version__ = "8.2.73"
import os
@ -8,7 +8,7 @@ import os
os.environ["OMP_NUM_THREADS"] = "1" # reduce CPU utilization during training
from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import NAS, RTDETR, SAM, SAM2, YOLO, FastSAM, YOLOWorld
from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld
from ultralytics.utils import ASSETS, SETTINGS
from ultralytics.utils.checks import check_yolo as checks
from ultralytics.utils.downloads import download
@ -21,7 +21,6 @@ __all__ = (
"YOLOWorld",
"NAS",
"SAM",
"SAM2",
"FastSAM",
"RTDETR",
"checks",

@ -4,7 +4,6 @@ from .fastsam import FastSAM
from .nas import NAS
from .rtdetr import RTDETR
from .sam import SAM
from .sam2 import SAM2
from .yolo import YOLO, YOLOWorld
__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld", "SAM2" # allow simpler import
__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld" # allow simpler import

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .model import SAM
from .predict import Predictor
from .predict import Predictor, SAM2Predictor
__all__ = "SAM", "Predictor" # tuple or list
__all__ = "SAM", "Predictor", "SAM2Predictor" # tuple or list

@ -11,7 +11,7 @@ import torch
def is_box_near_crop_edge(
boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
) -> torch.Tensor:
"""Return a boolean tensor indicating if boxes are near the crop edge."""
"""Determines if bounding boxes are near the edge of a cropped image region using a specified tolerance."""
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
@ -22,7 +22,7 @@ def is_box_near_crop_edge(
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
"""Yield batches of data from the input arguments."""
"""Yields batches of data from input arguments with specified batch size for efficient processing."""
assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
for b in range(n_batches):
@ -33,12 +33,26 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
"""
Computes the stability score for a batch of masks.
The stability score is the IoU between the binary masks obtained by thresholding the predicted mask logits at high
and low values.
The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
high and low values.
Args:
masks (torch.Tensor): Batch of predicted mask logits.
mask_threshold (float): Threshold value for creating binary masks.
threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.
Returns:
(torch.Tensor): Stability scores for each mask in the batch.
Notes:
- One mask is always contained inside the other.
- Save memory by preventing unnecessary cast to torch.int64
- Memory is saved by preventing unnecessary cast to torch.int64.
Examples:
>>> masks = torch.rand(10, 256, 256) # Batch of 10 masks
>>> mask_threshold = 0.5
>>> threshold_offset = 0.1
>>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
"""
intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
@ -46,7 +60,7 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
def build_point_grid(n_per_side: int) -> np.ndarray:
"""Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1]."""
"""Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks."""
offset = 1 / (2 * n_per_side)
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
@ -55,18 +69,14 @@ def build_point_grid(n_per_side: int) -> np.ndarray:
def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
"""Generate point grids for all crop layers."""
"""Generates point grids for multiple crop layers with varying scales and densities."""
return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
def generate_crop_boxes(
im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
) -> Tuple[List[List[int]], List[int]]:
"""
Generates a list of crop boxes of different sizes.
Each layer has (2**i)**2 boxes for the ith layer.
"""
"""Generates crop boxes of varying sizes for multi-scale image processing, with layered overlapping regions."""
crop_boxes, layer_idxs = [], []
im_h, im_w = im_size
short_side = min(im_h, im_w)
@ -99,7 +109,7 @@ def generate_crop_boxes(
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
"""Uncrop bounding boxes by adding the crop box offset."""
"""Uncrop bounding boxes by adding the crop box offset to their coordinates."""
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
# Check if boxes has a channel dimension
@ -109,7 +119,7 @@ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
"""Uncrop points by adding the crop box offset."""
"""Uncrop points by adding the crop box offset to their coordinates."""
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0]], device=points.device)
# Check if points has a channel dimension
@ -119,7 +129,7 @@ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
"""Uncrop masks by padding them to the original image size."""
"""Uncrop masks by padding them to the original image size, handling coordinate transformations."""
x0, y0, x1, y1 = crop_box
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
return masks
@ -130,7 +140,7 @@ def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w:
def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
"""Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator."""
"""Removes small disconnected regions or holes in a mask based on area threshold and mode."""
import cv2 # type: ignore
assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
@ -150,11 +160,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
"""
Calculates boxes in XYXY format around masks.
Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
"""
"""Calculates bounding boxes in XYXY format around binary masks, handling empty masks and various input shapes."""
# torch.max below raises an error on empty inputs, just skip in this case
if torch.numel(masks) == 0:
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)

@ -13,14 +13,15 @@ import torch
from ultralytics.utils.downloads import attempt_download_asset
from .modules.decoders import MaskDecoder
from .modules.encoders import ImageEncoderViT, PromptEncoder
from .modules.sam import SAMModel
from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
from .modules.sam import SAM2Model, SAMModel
from .modules.tiny_encoder import TinyViT
from .modules.transformer import TwoWayTransformer
def build_sam_vit_h(checkpoint=None):
"""Build and return a Segment Anything Model (SAM) h-size model."""
"""Builds and returns a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
return _build_sam(
encoder_embed_dim=1280,
encoder_depth=32,
@ -31,7 +32,7 @@ def build_sam_vit_h(checkpoint=None):
def build_sam_vit_l(checkpoint=None):
"""Build and return a Segment Anything Model (SAM) l-size model."""
"""Builds and returns a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
return _build_sam(
encoder_embed_dim=1024,
encoder_depth=24,
@ -42,7 +43,7 @@ def build_sam_vit_l(checkpoint=None):
def build_sam_vit_b(checkpoint=None):
"""Build and return a Segment Anything Model (SAM) b-size model."""
"""Constructs and returns a Segment Anything Model (SAM) with b-size architecture and optional checkpoint."""
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
@ -53,7 +54,7 @@ def build_sam_vit_b(checkpoint=None):
def build_mobile_sam(checkpoint=None):
"""Build and return Mobile Segment Anything Model (Mobile-SAM)."""
"""Builds and returns a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
return _build_sam(
encoder_embed_dim=[64, 128, 160, 320],
encoder_depth=[2, 2, 6, 2],
@ -64,10 +65,85 @@ def build_mobile_sam(checkpoint=None):
)
def build_sam2_t(checkpoint=None):
"""Builds and returns a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
return _build_sam2(
encoder_embed_dim=96,
encoder_stages=[1, 2, 7, 2],
encoder_num_heads=1,
encoder_global_att_blocks=[5, 7, 9],
encoder_window_spec=[8, 4, 14, 7],
encoder_backbone_channel_list=[768, 384, 192, 96],
checkpoint=checkpoint,
)
def build_sam2_s(checkpoint=None):
"""Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters."""
return _build_sam2(
encoder_embed_dim=96,
encoder_stages=[1, 2, 11, 2],
encoder_num_heads=1,
encoder_global_att_blocks=[7, 10, 13],
encoder_window_spec=[8, 4, 14, 7],
encoder_backbone_channel_list=[768, 384, 192, 96],
checkpoint=checkpoint,
)
def build_sam2_b(checkpoint=None):
"""Builds and returns a SAM2 base-size model with specified architecture parameters."""
return _build_sam2(
encoder_embed_dim=112,
encoder_stages=[2, 3, 16, 3],
encoder_num_heads=2,
encoder_global_att_blocks=[12, 16, 20],
encoder_window_spec=[8, 4, 14, 7],
encoder_window_spatial_size=[14, 14],
encoder_backbone_channel_list=[896, 448, 224, 112],
checkpoint=checkpoint,
)
def build_sam2_l(checkpoint=None):
"""Builds and returns a large-size Segment Anything Model (SAM2) with specified architecture parameters."""
return _build_sam2(
encoder_embed_dim=144,
encoder_stages=[2, 6, 36, 4],
encoder_num_heads=2,
encoder_global_att_blocks=[23, 33, 43],
encoder_window_spec=[8, 4, 16, 8],
encoder_backbone_channel_list=[1152, 576, 288, 144],
checkpoint=checkpoint,
)
def _build_sam(
encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, mobile_sam=False
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
mobile_sam=False,
):
"""Builds the selected SAM model architecture."""
"""
Builds a Segment Anything Model (SAM) with specified encoder parameters.
Args:
encoder_embed_dim (int | List[int]): Embedding dimension for the encoder.
encoder_depth (int | List[int]): Depth of the encoder.
encoder_num_heads (int | List[int]): Number of attention heads in the encoder.
encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder.
checkpoint (str | None): Path to the model checkpoint file.
mobile_sam (bool): Whether to build a Mobile-SAM model.
Returns:
(SAMModel): A Segment Anything Model instance with the specified architecture.
Examples:
>>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11])
>>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True)
"""
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
@ -139,16 +215,131 @@ def _build_sam(
return sam
def _build_sam2(
encoder_embed_dim=1280,
encoder_stages=[2, 6, 36, 4],
encoder_num_heads=2,
encoder_global_att_blocks=[7, 15, 23, 31],
encoder_backbone_channel_list=[1152, 576, 288, 144],
encoder_window_spatial_size=[7, 7],
encoder_window_spec=[8, 4, 16, 8],
checkpoint=None,
):
"""
Builds and returns a Segment Anything Model 2 (SAM2) with specified architecture parameters.
Args:
encoder_embed_dim (int): Embedding dimension for the encoder.
encoder_stages (List[int]): Number of blocks in each stage of the encoder.
encoder_num_heads (int): Number of attention heads in the encoder.
encoder_global_att_blocks (List[int]): Indices of global attention blocks in the encoder.
encoder_backbone_channel_list (List[int]): Channel dimensions for each level of the encoder backbone.
encoder_window_spatial_size (List[int]): Spatial size of the window for position embeddings.
encoder_window_spec (List[int]): Window specifications for each stage of the encoder.
checkpoint (str | None): Path to the checkpoint file for loading pre-trained weights.
Returns:
(SAM2Model): A configured and initialized SAM2 model.
Examples:
>>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])
>>> sam2_model.eval()
"""
image_encoder = ImageEncoder(
trunk=Hiera(
embed_dim=encoder_embed_dim,
num_heads=encoder_num_heads,
stages=encoder_stages,
global_att_blocks=encoder_global_att_blocks,
window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
window_spec=encoder_window_spec,
),
neck=FpnNeck(
d_model=256,
backbone_channel_list=encoder_backbone_channel_list,
fpn_top_down_levels=[2, 3],
fpn_interp_model="nearest",
),
scalp=1,
)
memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
memory_encoder = MemoryEncoder(out_dim=64)
sam2 = SAM2Model(
image_encoder=image_encoder,
memory_attention=memory_attention,
memory_encoder=memory_encoder,
num_maskmem=7,
image_size=1024,
sigmoid_scale_for_mem_enc=20.0,
sigmoid_bias_for_mem_enc=-10.0,
use_mask_input_as_output_without_sam=True,
directly_add_no_mem_embed=True,
use_high_res_features_in_sam=True,
multimask_output_in_sam=True,
iou_prediction_use_sigmoid=True,
use_obj_ptrs_in_encoder=True,
add_tpos_enc_to_obj_ptrs=True,
only_obj_ptrs_in_the_past_for_eval=True,
pred_obj_scores=True,
pred_obj_scores_mlp=True,
fixed_no_obj_ptr=True,
multimask_output_for_tracking=True,
use_multimask_token_for_obj_ptr=True,
multimask_min_pt_num=0,
multimask_max_pt_num=1,
use_mlp_for_obj_ptr_proj=True,
compile_image_encoder=False,
sam_mask_decoder_extra_args=dict(
dynamic_multimask_via_stability=True,
dynamic_multimask_stability_delta=0.05,
dynamic_multimask_stability_thresh=0.98,
),
)
if checkpoint is not None:
checkpoint = attempt_download_asset(checkpoint)
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)["model"]
sam2.load_state_dict(state_dict)
sam2.eval()
return sam2
sam_model_map = {
"sam_h.pt": build_sam_vit_h,
"sam_l.pt": build_sam_vit_l,
"sam_b.pt": build_sam_vit_b,
"mobile_sam.pt": build_mobile_sam,
"sam2_t.pt": build_sam2_t,
"sam2_s.pt": build_sam2_s,
"sam2_b.pt": build_sam2_b,
"sam2_l.pt": build_sam2_l,
}
def build_sam(ckpt="sam_b.pt"):
"""Build a SAM model specified by ckpt."""
"""
Builds and returns a Segment Anything Model (SAM) based on the provided checkpoint.
Args:
ckpt (str | Path): Path to the checkpoint file or name of a pre-defined SAM model.
Returns:
(SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.
Raises:
FileNotFoundError: If the provided checkpoint is not a supported SAM model.
Examples:
>>> sam_model = build_sam("sam_b.pt")
>>> sam_model = build_sam("path/to/custom_checkpoint.pt")
Notes:
Supported pre-defined models include:
- SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt'
- SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt'
"""
model_builder = None
ckpt = str(ckpt) # to allow Path ckpt types
for k in sam_model_map.keys():

@ -20,27 +20,46 @@ from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info
from .build import build_sam
from .predict import Predictor
from .predict import Predictor, SAM2Predictor
class SAM(Model):
"""
SAM (Segment Anything Model) interface class.
SAM is designed for promptable real-time image segmentation. It can be used with a variety of prompts such as
bounding boxes, points, or labels. The model has capabilities for zero-shot performance and is trained on the SA-1B
dataset.
SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for
promptable segmentation with versatility in image analysis. It supports various prompts such as bounding
boxes, points, or labels, and features zero-shot performance capabilities.
Attributes:
model (torch.nn.Module): The loaded SAM model.
is_sam2 (bool): Indicates whether the model is SAM2 variant.
task (str): The task type, set to "segment" for SAM models.
Methods:
predict: Performs segmentation prediction on the given image or video source.
info: Logs information about the SAM model.
Examples:
>>> sam = SAM('sam_b.pt')
>>> results = sam.predict('image.jpg', points=[[500, 375]])
>>> for r in results:
>>> print(f"Detected {len(r.masks)} masks")
"""
def __init__(self, model="sam_b.pt") -> None:
"""
Initializes the SAM model with a pre-trained model file.
Initializes the SAM (Segment Anything Model) instance.
Args:
model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
Raises:
NotImplementedError: If the model file extension is not .pt or .pth.
Examples:
>>> sam = SAM('sam_b.pt')
>>> print(sam.is_sam2)
"""
if model and Path(model).suffix not in {".pt", ".pth"}:
raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
@ -51,30 +70,40 @@ class SAM(Model):
"""
Loads the specified weights into the SAM model.
This method initializes the SAM model with the provided weights file, setting up the model architecture
and loading the pre-trained parameters.
Args:
weights (str): Path to the weights file.
task (str, optional): Task name. Defaults to None.
"""
if self.is_sam2:
from ..sam2.build import build_sam2
weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.
self.model = build_sam2(weights)
else:
self.model = build_sam(weights)
Examples:
>>> sam = SAM('sam_b.pt')
>>> sam._load('path/to/custom_weights.pt')
"""
self.model = build_sam(weights)
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""
Performs segmentation prediction on the given image or video source.
Args:
source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
stream (bool, optional): If True, enables real-time streaming. Defaults to False.
bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
points (list, optional): List of points for prompted segmentation. Defaults to None.
labels (list, optional): List of labels for prompted segmentation. Defaults to None.
source (str | PIL.Image | numpy.ndarray): Path to the image or video file, or a PIL.Image object, or
a numpy.ndarray object.
stream (bool): If True, enables real-time streaming.
bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
points (List[List[float]] | None): List of points for prompted segmentation.
labels (List[int] | None): List of labels for prompted segmentation.
**kwargs (Any): Additional keyword arguments for prediction.
Returns:
(list): The model predictions.
(List): The model predictions.
Examples:
>>> sam = SAM('sam_b.pt')
>>> results = sam.predict('image.jpg', points=[[500, 375]])
>>> for r in results:
... print(f"Detected {len(r.masks)} masks")
"""
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
kwargs.update(overrides)
@ -83,17 +112,27 @@ class SAM(Model):
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""
Alias for the 'predict' method.
Performs segmentation prediction on the given image or video source.
This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
for segmentation tasks.
Args:
source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
stream (bool, optional): If True, enables real-time streaming. Defaults to False.
bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
points (list, optional): List of points for prompted segmentation. Defaults to None.
labels (list, optional): List of labels for prompted segmentation. Defaults to None.
source (str | PIL.Image | numpy.ndarray | None): Path to the image or video file, or a PIL.Image
object, or a numpy.ndarray object.
stream (bool): If True, enables real-time streaming.
bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
points (List[List[float]] | None): List of points for prompted segmentation.
labels (List[int] | None): List of labels for prompted segmentation.
**kwargs (Any): Additional keyword arguments to be passed to the predict method.
Returns:
(list): The model predictions.
(List): The model predictions, typically containing segmentation masks and other relevant information.
Examples:
>>> sam = SAM('sam_b.pt')
>>> results = sam('image.jpg', points=[[500, 375]])
>>> print(f"Detected {len(results[0].masks)} masks")
"""
return self.predict(source, stream, bboxes, points, labels, **kwargs)
@ -101,12 +140,20 @@ class SAM(Model):
"""
Logs information about the SAM model.
This method provides details about the Segment Anything Model (SAM), including its architecture,
parameters, and computational requirements.
Args:
detailed (bool, optional): If True, displays detailed information about the model. Defaults to False.
verbose (bool, optional): If True, displays information on the console. Defaults to True.
detailed (bool): If True, displays detailed information about the model layers and operations.
verbose (bool): If True, prints the information to the console.
Returns:
(tuple): A tuple containing the model's information.
(Tuple): A tuple containing the model's information (string representations of the model).
Examples:
>>> sam = SAM('sam_b.pt')
>>> info = sam.info()
>>> print(info[0]) # Print summary information
"""
return model_info(self.model, detailed=detailed, verbose=verbose)
@ -116,8 +163,13 @@ class SAM(Model):
Provides a mapping from the 'segment' task to its corresponding 'Predictor'.
Returns:
(dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
(Dict[str, Type[Predictor]]): A dictionary mapping the 'segment' task to its corresponding Predictor
class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
Examples:
>>> sam = SAM('sam_b.pt')
>>> task_map = sam.task_map
>>> print(task_map)
{'segment': <class 'ultralytics.models.sam.predict.Predictor'>}
"""
from ..sam2.predict import SAM2Predictor
return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}

File diff suppressed because it is too large Load Diff

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from typing import List, Tuple, Type
from typing import List, Optional, Tuple, Type
import torch
from torch import nn
@ -10,12 +10,14 @@ from ultralytics.nn.modules import MLP, LayerNorm2d
class MaskDecoder(nn.Module):
"""
Decoder module for generating masks and their associated quality scores, using a transformer architecture to predict
masks given image and prompt embeddings.
Decoder module for generating masks and their associated quality scores using a transformer architecture.
This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and
generate mask predictions along with their quality scores.
Attributes:
transformer_dim (int): Channel dimension for the transformer module.
transformer (nn.Module): The transformer module used for mask prediction.
transformer (nn.Module): Transformer module used for mask prediction.
num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
iou_token (nn.Embedding): Embedding for the IoU token.
num_mask_tokens (int): Number of mask tokens.
@ -23,6 +25,16 @@ class MaskDecoder(nn.Module):
output_upscaling (nn.Sequential): Neural network sequence for upscaling the output.
output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks.
iou_prediction_head (nn.Module): MLP for predicting mask quality.
Methods:
forward: Predicts masks given image and prompt embeddings.
predict_masks: Internal method for mask prediction.
Examples:
>>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
>>> masks, iou_pred = decoder(image_embeddings, image_pe, sparse_prompt_embeddings,
... dense_prompt_embeddings, multimask_output=True)
>>> print(f"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
"""
def __init__(
@ -35,15 +47,20 @@ class MaskDecoder(nn.Module):
iou_head_hidden_dim: int = 256,
) -> None:
"""
Predicts masks given an image and prompt embeddings, using a transformer architecture.
Initializes the MaskDecoder module for generating masks and their quality scores.
Args:
transformer_dim (int): the channel dimension of the transformer module
transformer (nn.Module): the transformer used to predict masks
num_multimask_outputs (int): the number of masks to predict when disambiguating masks
activation (nn.Module): the type of activation to use when upscaling masks
iou_head_depth (int): the depth of the MLP used to predict mask quality
iou_head_hidden_dim (int): the hidden dimension of the MLP used to predict mask quality
transformer_dim (int): Channel dimension for the transformer module.
transformer (nn.Module): Transformer module used for mask prediction.
num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
activation (Type[nn.Module]): Type of activation to use when upscaling masks.
iou_head_depth (int): Depth of the MLP used to predict mask quality.
iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
Examples:
>>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
>>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer)
>>> print(decoder)
"""
super().__init__()
self.transformer_dim = transformer_dim
@ -77,18 +94,28 @@ class MaskDecoder(nn.Module):
multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Predicts masks given image and prompt embeddings.
Args:
image_embeddings (torch.Tensor): the embeddings from the image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
image_embeddings (torch.Tensor): Embeddings from the image encoder.
image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings.
sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes.
dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs.
multimask_output (bool): Whether to return multiple masks or a single mask.
Returns:
torch.Tensor: batched predicted masks
torch.Tensor: batched predictions of mask quality
(Tuple[torch.Tensor, torch.Tensor]): A tuple containing:
- masks (torch.Tensor): Batched predicted masks.
- iou_pred (torch.Tensor): Batched predictions of mask quality.
Examples:
>>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
>>> image_emb = torch.rand(1, 256, 64, 64)
>>> image_pe = torch.rand(1, 256, 64, 64)
>>> sparse_emb = torch.rand(1, 2, 256)
>>> dense_emb = torch.rand(1, 256, 64, 64)
>>> masks, iou_pred = decoder(image_emb, image_pe, sparse_emb, dense_emb, multimask_output=True)
>>> print(f"Masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
"""
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
@ -112,11 +139,7 @@ class MaskDecoder(nn.Module):
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predicts masks.
See 'forward' for more details.
"""
"""Predicts masks and quality scores using image and prompt embeddings via transformer architecture."""
# Concatenate output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)
@ -147,3 +170,347 @@ class MaskDecoder(nn.Module):
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
class SAM2MaskDecoder(nn.Module):
"""
Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.
This class extends the functionality of the MaskDecoder, incorporating additional features such as
high-resolution feature processing, dynamic multimask output, and object score prediction.
Attributes:
transformer_dim (int): Channel dimension of the transformer.
transformer (nn.Module): Transformer used to predict masks.
num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
iou_token (nn.Embedding): Embedding for IOU token.
num_mask_tokens (int): Total number of mask tokens.
mask_tokens (nn.Embedding): Embedding for mask tokens.
pred_obj_scores (bool): Whether to predict object scores.
obj_score_token (nn.Embedding): Embedding for object score token.
use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
output_upscaling (nn.Sequential): Upscaling layers for output.
use_high_res_features (bool): Whether to use high-resolution features.
conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0).
conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1).
output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks.
iou_prediction_head (MLP): MLP for IOU prediction.
pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction.
dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
Methods:
forward: Predicts masks given image and prompt embeddings.
predict_masks: Predicts instance segmentation masks from image and prompt embeddings.
_get_stability_scores: Computes mask stability scores based on IoU between thresholds.
_dynamic_multimask_via_stability: Dynamically selects the most stable mask output.
Examples:
>>> image_embeddings = torch.rand(1, 256, 64, 64)
>>> image_pe = torch.rand(1, 256, 64, 64)
>>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
>>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
>>> decoder = SAM2MaskDecoder(256, transformer)
>>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False)
"""
def __init__(
self,
transformer_dim: int,
transformer: nn.Module,
num_multimask_outputs: int = 3,
activation: Type[nn.Module] = nn.GELU,
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
use_high_res_features: bool = False,
iou_prediction_use_sigmoid=False,
dynamic_multimask_via_stability=False,
dynamic_multimask_stability_delta=0.05,
dynamic_multimask_stability_thresh=0.98,
pred_obj_scores: bool = False,
pred_obj_scores_mlp: bool = False,
use_multimask_token_for_obj_ptr: bool = False,
) -> None:
"""
Initializes the SAM2MaskDecoder module for predicting instance segmentation masks.
This decoder extends the functionality of MaskDecoder, incorporating additional features such as
high-resolution feature processing, dynamic multimask output, and object score prediction.
Args:
transformer_dim (int): Channel dimension of the transformer.
transformer (nn.Module): Transformer used to predict masks.
num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
activation (Type[nn.Module]): Type of activation to use when upscaling masks.
iou_head_depth (int): Depth of the MLP used to predict mask quality.
iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
use_high_res_features (bool): Whether to use high-resolution features.
iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction.
dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
pred_obj_scores (bool): Whether to predict object scores.
pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction.
use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
Examples:
>>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
>>> decoder = SAM2MaskDecoder(transformer_dim=256, transformer=transformer)
>>> print(decoder)
"""
super().__init__()
self.transformer_dim = transformer_dim
self.transformer = transformer
self.num_multimask_outputs = num_multimask_outputs
self.iou_token = nn.Embedding(1, transformer_dim)
self.num_mask_tokens = num_multimask_outputs + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
self.pred_obj_scores = pred_obj_scores
if self.pred_obj_scores:
self.obj_score_token = nn.Embedding(1, transformer_dim)
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
activation(),
)
self.use_high_res_features = use_high_res_features
if use_high_res_features:
self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1)
self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1)
self.output_hypernetworks_mlps = nn.ModuleList(
[MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
)
self.iou_prediction_head = MLP(
transformer_dim,
iou_head_hidden_dim,
self.num_mask_tokens,
iou_head_depth,
sigmoid=iou_prediction_use_sigmoid,
)
if self.pred_obj_scores:
self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
if pred_obj_scores_mlp:
self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
# When outputting a single mask, optionally we can dynamically fall back to the best
# multimask output token if the single mask output token gives low stability scores.
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
repeat_image: bool,
high_res_features: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predicts masks given image and prompt embeddings.
Args:
image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W).
image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings (B, C, H, W).
sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes with shape (B, N, C).
dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs with shape (B, C, H, W).
multimask_output (bool): Whether to return multiple masks or a single mask.
repeat_image (bool): Flag to repeat the image embeddings.
high_res_features (List[torch.Tensor] | None): Optional high-resolution features.
Returns:
(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
- masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W).
- iou_pred (torch.Tensor): Batched predictions of mask quality with shape (B, N).
- sam_tokens_out (torch.Tensor): Batched SAM token for mask output with shape (B, N, C).
- object_score_logits (torch.Tensor): Batched object score logits with shape (B, 1).
Examples:
>>> image_embeddings = torch.rand(1, 256, 64, 64)
>>> image_pe = torch.rand(1, 256, 64, 64)
>>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
>>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
>>> decoder = SAM2MaskDecoder(256, transformer)
>>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False)
"""
masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
repeat_image=repeat_image,
high_res_features=high_res_features,
)
# Select the correct mask or masks for output
if multimask_output:
masks = masks[:, 1:, :, :]
iou_pred = iou_pred[:, 1:]
elif self.dynamic_multimask_via_stability and not self.training:
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
else:
masks = masks[:, 0:1, :, :]
iou_pred = iou_pred[:, 0:1]
if multimask_output and self.use_multimask_token_for_obj_ptr:
sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
else:
# Take the mask output token. Here we *always* use the token for single mask output.
# At test time, even if we track after 1-click (and using multimask_output=True),
# we still take the single mask token here. The rationale is that we always track
# after multiple clicks during training, so the past tokens seen during training
# are always the single mask token (and we'll let it be the object-memory token).
sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
# Prepare output
return masks, iou_pred, sam_tokens_out, object_score_logits
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
repeat_image: bool,
high_res_features: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts instance segmentation masks from image and prompt embeddings using a transformer."""
# Concatenate output tokens
s = 0
if self.pred_obj_scores:
output_tokens = torch.cat(
[
self.obj_score_token.weight,
self.iou_token.weight,
self.mask_tokens.weight,
],
dim=0,
)
s = 1
else:
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
if repeat_image:
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
else:
assert image_embeddings.shape[0] == tokens.shape[0]
src = image_embeddings
src = src + dense_prompt_embeddings
assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, s, :]
mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
if not self.use_high_res_features:
upscaled_embedding = self.output_upscaling(src)
else:
dc1, ln1, act1, dc2, act2 = self.output_upscaling
feat_s0, feat_s1 = high_res_features
upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
if self.pred_obj_scores:
assert s == 1
object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
else:
# Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
return masks, iou_pred, mask_tokens_out, object_score_logits
def _get_stability_scores(self, mask_logits):
"""Computes mask stability scores based on IoU between upper and lower thresholds."""
mask_logits = mask_logits.flatten(-2)
stability_delta = self.dynamic_multimask_stability_delta
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
return stability_scores
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
"""
Dynamically selects the most stable mask output based on stability scores and IoU predictions.
This method is used when outputting a single mask. If the stability score from the current single-mask
output (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs
(based on output tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask
for both clicking and tracking scenarios.
Args:
all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is
batch size, N is number of masks (typically 4), and H, W are mask dimensions.
all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N).
Returns:
(Tuple[torch.Tensor, torch.Tensor]):
- mask_logits_out (torch.Tensor): Selected mask logits, shape (B, 1, H, W).
- iou_scores_out (torch.Tensor): Selected IoU scores, shape (B, 1).
Examples:
>>> decoder = SAM2MaskDecoder(...)
>>> all_mask_logits = torch.rand(2, 4, 256, 256) # 2 images, 4 masks each
>>> all_iou_scores = torch.rand(2, 4)
>>> mask_logits, iou_scores = decoder._dynamic_multimask_via_stability(all_mask_logits, all_iou_scores)
>>> print(mask_logits.shape, iou_scores.shape)
torch.Size([2, 1, 256, 256]) torch.Size([2, 1])
"""
# The best mask from multimask output tokens (1~3)
multimask_logits = all_mask_logits[:, 1:, :, :]
multimask_iou_scores = all_iou_scores[:, 1:]
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
best_multimask_logits = best_multimask_logits.unsqueeze(1)
best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
# The mask from singlemask output token 0 and its stability score
singlemask_logits = all_mask_logits[:, 0:1, :, :]
singlemask_iou_scores = all_iou_scores[:, 0:1]
stability_scores = self._get_stability_scores(singlemask_logits)
is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
# Dynamically fall back to best multimask output upon low stability scores.
mask_logits_out = torch.where(
is_stable[..., None, None].expand_as(singlemask_logits),
singlemask_logits,
best_multimask_logits,
)
iou_scores_out = torch.where(
is_stable.expand_as(singlemask_iou_scores),
singlemask_iou_scores,
best_multimask_iou_scores,
)
return mask_logits_out, iou_scores_out

@ -1,30 +1,48 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from typing import Any, Optional, Tuple, Type
from typing import List, Optional, Tuple, Type
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.nn.modules import LayerNorm2d, MLPBlock
from ultralytics.nn.modules import LayerNorm2d
from .blocks import (
Block,
CXBlock,
Fuser,
MaskDownSampler,
MultiScaleBlock,
PatchEmbed,
PositionEmbeddingRandom,
PositionEmbeddingSine,
)
class ImageEncoderViT(nn.Module):
"""
An image encoder using Vision Transformer (ViT) architecture for encoding an image into a compact latent space. The
encoder takes an image, splits it into patches, and processes these patches through a series of transformer blocks.
The encoded patches are then processed through a neck to generate the final encoded representation.
An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.
This class and its supporting functions below lightly adapted from the ViTDet backbone available at
https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py.
This class processes images by splitting them into patches, applying transformer blocks, and generating a final
encoded representation through a neck module.
Attributes:
img_size (int): Dimension of input images, assumed to be square.
patch_embed (PatchEmbed): Module for patch embedding.
pos_embed (nn.Parameter, optional): Absolute positional embedding for patches.
pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.
neck (nn.Sequential): Neck module to further process the output.
Methods:
forward: Processes input through patch embedding, positional embedding, blocks, and neck.
Examples:
>>> import torch
>>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
>>> input_image = torch.randn(1, 3, 224, 224)
>>> output = encoder(input_image)
>>> print(output.shape)
"""
def __init__(
@ -47,22 +65,38 @@ class ImageEncoderViT(nn.Module):
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
Initializes an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
img_size (int): Input image size, assumed to be square.
patch_size (int): Size of image patches.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
embed_dim (int): Dimension of patch embeddings.
depth (int): Number of transformer blocks.
num_heads (int): Number of attention heads in each block.
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
out_chans (int): Number of output channels from the neck module.
qkv_bias (bool): If True, adds learnable bias to query, key, value projections.
norm_layer (Type[nn.Module]): Type of normalization layer to use.
act_layer (Type[nn.Module]): Type of activation layer to use.
use_abs_pos (bool): If True, uses absolute positional embeddings.
use_rel_pos (bool): If True, adds relative positional embeddings to attention maps.
rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
window_size (int): Size of attention window for windowed attention blocks.
global_attn_indexes (Tuple[int, ...]): Indices of blocks that use global attention.
Attributes:
img_size (int): Dimension of input images.
patch_embed (PatchEmbed): Module for patch embedding.
pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
blocks (nn.ModuleList): List of transformer blocks.
neck (nn.Sequential): Neck module for final processing.
Examples:
>>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
>>> input_image = torch.randn(1, 3, 224, 224)
>>> output = encoder(input_image)
>>> print(output.shape)
"""
super().__init__()
self.img_size = img_size
@ -114,9 +148,7 @@ class ImageEncoderViT(nn.Module):
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Processes input through patch embedding, applies positional embedding if present, and passes through blocks
and neck.
"""
"""Processes input through patch embedding, positional embedding, transformer blocks, and neck module."""
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
@ -127,8 +159,7 @@ class ImageEncoderViT(nn.Module):
class PromptEncoder(nn.Module):
"""
Encodes different types of prompts, including points, boxes, and masks, for input to SAM's mask decoder. The encoder
produces both sparse and dense embeddings for the input prompts.
Encodes different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.
Attributes:
embed_dim (int): Dimension of the embeddings.
@ -137,10 +168,23 @@ class PromptEncoder(nn.Module):
pe_layer (PositionEmbeddingRandom): Module for random position embedding.
num_point_embeddings (int): Number of point embeddings for different types of points.
point_embeddings (nn.ModuleList): List of point embeddings.
not_a_point_embed (nn.Embedding): Embedding for points that are not a part of any label.
not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
mask_input_size (Tuple[int, int]): Size of the input mask.
mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.
Methods:
get_dense_pe: Returns the positional encoding used to encode point prompts.
forward: Embeds different types of prompts, returning both sparse and dense embeddings.
Examples:
>>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
>>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
>>> boxes = torch.rand(1, 2, 2)
>>> masks = torch.rand(1, 1, 256, 256)
>>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
>>> print(sparse_embeddings.shape, dense_embeddings.shape)
torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
"""
def __init__(
@ -152,18 +196,37 @@ class PromptEncoder(nn.Module):
activation: Type[nn.Module] = nn.GELU,
) -> None:
"""
Encodes prompts for input to SAM's mask decoder.
Initializes the PromptEncoder module for encoding various types of prompts.
This module encodes different types of prompts (points, boxes, masks) for input to SAM's mask decoder,
producing both sparse and dense embeddings.
Args:
embed_dim (int): The prompts' embedding dimension
image_embedding_size (tuple(int, int)): The spatial size of the
image embedding, as (H, W).
input_image_size (int): The padded size of the image as input
to the image encoder, as (H, W).
mask_in_chans (int): The number of hidden channels used for
encoding input masks.
activation (nn.Module): The activation to use when encoding
input masks.
embed_dim (int): The dimension of the embeddings.
image_embedding_size (Tuple[int, int]): The spatial size of the image embedding as (H, W).
input_image_size (Tuple[int, int]): The padded size of the input image as (H, W).
mask_in_chans (int): The number of hidden channels used for encoding input masks.
activation (Type[nn.Module]): The activation function to use when encoding input masks.
Attributes:
embed_dim (int): Dimension of the embeddings.
input_image_size (Tuple[int, int]): Size of the input image as (H, W).
image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).
pe_layer (PositionEmbeddingRandom): Module for random position embedding.
num_point_embeddings (int): Number of point embeddings for different types of points.
point_embeddings (nn.ModuleList): List of point embeddings.
not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
mask_input_size (Tuple[int, int]): Size of the input mask.
mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
Examples:
>>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
>>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
>>> boxes = torch.rand(1, 2, 2)
>>> masks = torch.rand(1, 1, 256, 256)
>>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
>>> print(sparse_embeddings.shape, dense_embeddings.shape)
torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
"""
super().__init__()
self.embed_dim = embed_dim
@ -190,16 +253,25 @@ class PromptEncoder(nn.Module):
def get_dense_pe(self) -> torch.Tensor:
"""
Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the
image encoding.
Returns the dense positional encoding used for encoding point prompts.
This method generates a positional encoding for a dense set of points matching the shape of the image
encoding. The encoding is used to provide spatial information to the model when processing point prompts.
Returns:
torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
(torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the
height and width of the image embedding size, respectively.
Examples:
>>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
>>> dense_pe = prompt_encoder.get_dense_pe()
>>> print(dense_pe.shape)
torch.Size([1, 256, 64, 64])
"""
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
"""Embeds point prompts."""
"""Embeds point prompts by applying positional encoding and label-specific embeddings."""
points = points + 0.5 # Shift to center of pixel
if pad:
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
@ -216,7 +288,7 @@ class PromptEncoder(nn.Module):
return point_embedding
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
"""Embeds box prompts."""
"""Embeds box prompts by applying positional encoding and adding corner embeddings."""
boxes = boxes + 0.5 # Shift to center of pixel
coords = boxes.reshape(-1, 2, 2)
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
@ -225,7 +297,7 @@ class PromptEncoder(nn.Module):
return corner_embedding
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
"""Embeds mask inputs."""
"""Embeds mask inputs by downscaling and processing through convolutional layers."""
return self.mask_downscaling(masks)
@staticmethod
@ -258,14 +330,25 @@ class PromptEncoder(nn.Module):
Embeds different types of prompts, returning both sparse and dense embeddings.
Args:
points (tuple(torch.Tensor, torch.Tensor), None): point coordinates and labels to embed.
boxes (torch.Tensor, None): boxes to embed
masks (torch.Tensor, None): masks to embed
points (Tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first
tensor contains coordinates with shape (B, N, 2), and the second tensor contains labels with
shape (B, N).
boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes.
masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W).
Returns:
torch.Tensor: sparse embeddings for the points and boxes, with shape BxNx(embed_dim), where N is determined
by the number of input points and boxes.
torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W)
(Tuple[torch.Tensor, torch.Tensor]): A tuple containing:
- sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim).
- dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W).
Examples:
>>> encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
>>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
>>> boxes = torch.rand(1, 2, 2, 2)
>>> masks = torch.rand(1, 1, 256, 256)
>>> sparse_emb, dense_emb = encoder(points, boxes, masks)
>>> print(sparse_emb.shape, dense_emb.shape)
torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
"""
bs = self._get_batch_size(points, boxes, masks)
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
@ -287,319 +370,421 @@ class PromptEncoder(nn.Module):
return sparse_embeddings, dense_embeddings
class PositionEmbeddingRandom(nn.Module):
"""Positional encoding using random spatial frequencies."""
class MemoryEncoder(nn.Module):
"""
Encodes pixel features and masks into a memory representation for efficient image segmentation.
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
"""Initializes a position embedding using random spatial frequencies."""
super().__init__()
if scale is None or scale <= 0.0:
scale = 1.0
self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))
# Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = False
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
"""Positionally encode points that are normalized to [0,1]."""
# Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coords = 2 * coords - 1
coords = coords @ self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
# Outputs d_1 x ... x d_n x C shape
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
"""Generate positional encoding for a grid of the specified size."""
h, w = size
device: Any = self.positional_encoding_gaussian_matrix.device
grid = torch.ones((h, w), device=device, dtype=torch.float32)
y_embed = grid.cumsum(dim=0) - 0.5
x_embed = grid.cumsum(dim=1) - 0.5
y_embed = y_embed / h
x_embed = x_embed / w
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
return pe.permute(2, 0, 1) # C x H x W
def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
"""Positionally encode points that are not normalized to [0,1]."""
coords = coords_input.clone()
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
return self._pe_encoding(coords.to(torch.float)) # B x N x C
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks."""
This class processes pixel-level features and masks, fusing them to generate encoded memory representations
suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
Attributes:
mask_downsampler (MaskDownSampler): Module for downsampling input masks.
pix_feat_proj (nn.Conv2d): Convolutional layer for projecting pixel features.
fuser (Fuser): Module for fusing pixel features and masks.
position_encoding (PositionEmbeddingSine): Module for adding positional encoding to features.
out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d.
Methods:
forward: Processes input pixel features and masks to generate encoded memory representations.
Examples:
>>> import torch
>>> encoder = MemoryEncoder(out_dim=256, in_dim=256)
>>> pix_feat = torch.randn(1, 256, 64, 64)
>>> masks = torch.randn(1, 1, 64, 64)
>>> encoded_feat, pos = encoder(pix_feat, masks)
>>> print(encoded_feat.shape, pos.shape)
torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])
"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
input_size (tuple(int, int), None): Input resolution for calculating the relative
positional parameter size.
"""
out_dim,
in_dim=256, # in_dim of pix_feats
):
"""Initializes the MemoryEncoder for encoding pixel features and masks into memory representations."""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
self.window_size = window_size
self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)
self.out_proj = nn.Identity()
if out_dim != in_dim:
self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Executes a forward pass through the transformer block with window attention and non-overlapping windows."""
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
def forward(
self,
pix_feat: torch.Tensor,
masks: torch.Tensor,
skip_mask_sigmoid: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Processes pixel features and masks to generate encoded memory representations for segmentation."""
if not skip_mask_sigmoid:
masks = F.sigmoid(masks)
masks = self.mask_downsampler(masks)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
# Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA
pix_feat = pix_feat.to(masks.device)
x = shortcut + x
return x + self.mlp(self.norm2(x))
x = self.pix_feat_proj(pix_feat)
x = x + masks
x = self.fuser(x)
x = self.out_proj(x)
pos = self.position_encoding(x).to(x.dtype)
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
return {"vision_features": x, "vision_pos_enc": [pos]}
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Initialize Attention module.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (tuple(int, int), None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
class ImageEncoder(nn.Module):
"""
Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings.
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
This class combines a trunk network for feature extraction with a neck network for feature refinement
and positional encoding generation. It can optionally discard the lowest resolution features.
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert input_size is not None, "Input size must be provided if using relative positional encoding."
# Initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
Attributes:
trunk (nn.Module): The trunk network for initial feature extraction.
neck (nn.Module): The neck network for feature refinement and positional encoding generation.
scalp (int): Number of lowest resolution feature levels to discard.
Methods:
forward: Processes the input image through the trunk and neck networks.
Examples:
>>> trunk = SomeTrunkNetwork()
>>> neck = SomeNeckNetwork()
>>> encoder = ImageEncoder(trunk, neck, scalp=1)
>>> image = torch.randn(1, 3, 224, 224)
>>> output = encoder(image)
>>> print(output.keys())
dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Applies the forward operation including attention, normalization, MLP, and indexing within window limits."""
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
def __init__(
self,
trunk: nn.Module,
neck: nn.Module,
scalp: int = 0,
):
"""Initializes the ImageEncoder with trunk and neck networks for feature extraction and refinement."""
super().__init__()
self.trunk = trunk
self.neck = neck
self.scalp = scalp
assert (
self.trunk.channel_list == self.neck.backbone_channel_list
), f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
def forward(self, sample: torch.Tensor):
"""Encodes input through patch embedding, positional embedding, transformer blocks, and neck module."""
features, pos = self.neck(self.trunk(sample))
if self.scalp > 0:
# Discard the lowest resolution features
features, pos = features[: -self.scalp], pos[: -self.scalp]
src = features[-1]
output = {
"vision_features": src,
"vision_pos_enc": pos,
"backbone_fpn": features,
}
return output
class FpnNeck(nn.Module):
"""
A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.
attn = (q * self.scale) @ k.transpose(-2, -1)
This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
similar to ViT positional embedding interpolation.
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
Attributes:
position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module.
convs (nn.ModuleList): List of convolutional layers for each backbone level.
backbone_channel_list (List[int]): List of channel dimensions from the backbone.
fpn_interp_model (str): Interpolation mode for FPN feature resizing.
fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
fpn_top_down_levels (List[int]): Levels to have top-down features in outputs.
Methods:
forward: Performs forward pass through the FPN neck.
Examples:
>>> backbone_channels = [64, 128, 256, 512]
>>> fpn_neck = FpnNeck(256, backbone_channels)
>>> inputs = [torch.rand(1, c, 32, 32) for c in backbone_channels]
>>> outputs, positions = fpn_neck(inputs)
>>> print(len(outputs), len(positions))
4 4
"""
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
return self.proj(x)
def __init__(
self,
d_model: int,
backbone_channel_list: List[int],
kernel_size: int = 1,
stride: int = 1,
padding: int = 0,
fpn_interp_model: str = "bilinear",
fuse_type: str = "sum",
fpn_top_down_levels: Optional[List[int]] = None,
):
"""
Initializes a modified Feature Pyramid Network (FPN) neck.
This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
similar to ViT positional embedding interpolation.
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
Args:
d_model (int): Dimension of the model.
backbone_channel_list (List[int]): List of channel dimensions from the backbone.
kernel_size (int): Kernel size for the convolutional layers.
stride (int): Stride for the convolutional layers.
padding (int): Padding for the convolutional layers.
fpn_interp_model (str): Interpolation mode for FPN feature resizing.
fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs.
Examples:
>>> backbone_channels = [64, 128, 256, 512]
>>> fpn_neck = FpnNeck(256, backbone_channels)
>>> print(fpn_neck)
"""
super().__init__()
self.position_encoding = PositionEmbeddingSine(num_pos_feats=256)
self.convs = nn.ModuleList()
self.backbone_channel_list = backbone_channel_list
for dim in backbone_channel_list:
current = nn.Sequential()
current.add_module(
"conv",
nn.Conv2d(
in_channels=dim,
out_channels=d_model,
kernel_size=kernel_size,
stride=stride,
padding=padding,
),
)
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
self.convs.append(current)
self.fpn_interp_model = fpn_interp_model
assert fuse_type in ["sum", "avg"]
self.fuse_type = fuse_type
# levels to have top-down features in its outputs
# e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
# have top-down propagation, while outputs of level 0 and level 1 have only
# lateral features from the same backbone level.
if fpn_top_down_levels is None:
# default is to have top-down features on all levels
fpn_top_down_levels = range(len(self.convs))
self.fpn_top_down_levels = list(fpn_top_down_levels)
def forward(self, xs: List[torch.Tensor]):
"""
Performs forward pass through the Feature Pyramid Network (FPN) neck.
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
This method processes a list of input tensors from the backbone through the FPN, applying lateral connections
and top-down feature fusion. It generates output feature maps and corresponding positional encodings.
Args:
xs (List[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W).
def window_unpartition(
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
Returns:
(Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing:
- out (List[torch.Tensor]): List of output feature maps after FPN processing, each with shape
(B, d_model, H, W).
- pos (List[torch.Tensor]): List of positional encodings corresponding to each output feature map.
Examples:
>>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])
>>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]]
>>> outputs, positions = fpn_neck(inputs)
>>> print(len(outputs), len(positions))
4 4
"""
out = [None] * len(self.convs)
pos = [None] * len(self.convs)
assert len(xs) == len(self.convs)
# fpn forward pass
# see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
prev_features = None
# forward in top-down order (from low to high resolution)
n = len(self.convs) - 1
for i in range(n, -1, -1):
x = xs[i]
lateral_features = self.convs[n - i](x)
if i in self.fpn_top_down_levels and prev_features is not None:
top_down_features = F.interpolate(
prev_features.to(dtype=torch.float32),
scale_factor=2.0,
mode=self.fpn_interp_model,
align_corners=(None if self.fpn_interp_model == "nearest" else False),
antialias=False,
)
prev_features = lateral_features + top_down_features
if self.fuse_type == "avg":
prev_features /= 2
else:
prev_features = lateral_features
x_out = prev_features
out[i] = x_out
pos[i] = self.position_encoding(x_out).to(x_out.dtype)
return out, pos
class Hiera(nn.Module):
"""
Window unpartition into original sequences and removing padding.
Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for
efficient multiscale feature extraction. It uses a series of transformer blocks organized into stages,
with optional pooling and global attention mechanisms.
Returns:
x: unpartitioned sequences with [B, H, W, C].
Attributes:
window_spec (Tuple[int, ...]): Window sizes for each stage.
q_stride (Tuple[int, int]): Downsampling stride between stages.
stage_ends (List[int]): Indices of the last block in each stage.
q_pool_blocks (List[int]): Indices of blocks where pooling is applied.
return_interm_layers (bool): Whether to return intermediate layer outputs.
patch_embed (PatchEmbed): Module for patch embedding.
global_att_blocks (Tuple[int, ...]): Indices of blocks with global attention.
window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background.
pos_embed (nn.Parameter): Positional embedding for the background.
pos_embed_window (nn.Parameter): Positional embedding for the window.
blocks (nn.ModuleList): List of MultiScaleBlock modules.
channel_list (List[int]): List of output channel dimensions for each stage.
Methods:
_get_pos_embed: Generates positional embeddings by interpolating and combining window and background embeddings.
forward: Performs the forward pass through the Hiera model.
Examples:
>>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
>>> input_tensor = torch.randn(1, 3, 224, 224)
>>> output_features = model(input_tensor)
>>> for feat in output_features:
... print(feat.shape)
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
def __init__(
self,
embed_dim: int = 96, # initial embed dim
num_heads: int = 1, # initial number of heads
drop_path_rate: float = 0.0, # stochastic depth
q_pool: int = 3, # number of q_pool stages
q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
dim_mul: float = 2.0, # dim_mul factor at stage shift
head_mul: float = 2.0, # head_mul factor at stage shift
window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
# window size per stage, when not using global att.
window_spec: Tuple[int, ...] = (
8,
4,
14,
7,
),
# global attn in these blocks
global_att_blocks: Tuple[int, ...] = (
12,
16,
20,
),
return_interm_layers=True, # return feats from every stage
):
"""Initializes the Hiera model, configuring its hierarchical vision transformer architecture."""
super().__init__()
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of query and key sizes.
assert len(stages) == len(window_spec)
self.window_spec = window_spec
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
depth = sum(stages)
self.q_stride = q_stride
self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
assert 0 <= q_pool <= len(self.stage_ends[:-1])
self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
self.return_interm_layers = return_interm_layers
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
self.patch_embed = PatchEmbed(
embed_dim=embed_dim,
kernel_size=(7, 7),
stride=(4, 4),
padding=(3, 3),
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from mvitv2 paper at
https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py.
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
# Which blocks have global att?
self.global_att_blocks = global_att_blocks
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
# Windowed positional embedding (https://arxiv.org/abs/2311.05613)
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))
self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
B, q_h * q_w, k_h * k_w
)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
return attn
cur_stage = 1
self.blocks = nn.ModuleList()
for i in range(depth):
dim_out = embed_dim
# lags by a block, so first block of
# next stage uses an initial window size
# of previous stage and final window size of current stage
window_size = self.window_spec[cur_stage - 1]
class PatchEmbed(nn.Module):
"""Image to Patch Embedding."""
if self.global_att_blocks is not None:
window_size = 0 if i in self.global_att_blocks else window_size
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
padding: Tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Initialize PatchEmbed module.
if i - 1 in self.stage_ends:
dim_out = int(embed_dim * dim_mul)
num_heads = int(num_heads * head_mul)
cur_stage += 1
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
block = MultiScaleBlock(
dim=embed_dim,
dim_out=dim_out,
num_heads=num_heads,
drop_path=dpr[i],
q_stride=self.q_stride if i in self.q_pool_blocks else None,
window_size=window_size,
)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
embed_dim = dim_out
self.blocks.append(block)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Computes patch embedding by applying convolution and transposing resulting tensor."""
return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C
self.channel_list = (
[self.blocks[i].dim_out for i in self.stage_ends[::-1]]
if return_interm_layers
else [self.blocks[-1].dim_out]
)
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
"""Generates positional embeddings by interpolating and combining window and background embeddings."""
h, w = hw
window_embed = self.pos_embed_window
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
pos_embed = pos_embed.permute(0, 2, 3, 1)
return pos_embed
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
"""Performs forward pass through Hiera model, extracting multiscale features from input images."""
x = self.patch_embed(x)
# x: (B, H, W, C)
# Add pos embed
x = x + self._get_pos_embed(x.shape[1:3])
outputs = []
for i, blk in enumerate(self.blocks):
x = blk(x)
if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):
feats = x.permute(0, 3, 1, 2)
outputs.append(feats)
return outputs

@ -6,11 +6,50 @@ from typing import Optional
import torch
from torch import Tensor, nn
from .sam2_blocks import RoPEAttention
from .blocks import RoPEAttention
class MemoryAttentionLayer(nn.Module):
"""Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks."""
"""
Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.
This class combines self-attention, cross-attention, and feedforward components to process input tensors and
generate memory-based attention outputs.
Attributes:
d_model (int): Dimensionality of the model.
dim_feedforward (int): Dimensionality of the feedforward network.
dropout_value (float): Dropout rate for regularization.
self_attn (RoPEAttention): Self-attention mechanism using RoPE (Rotary Position Embedding).
cross_attn_image (RoPEAttention): Cross-attention mechanism for image processing.
linear1 (nn.Linear): First linear layer of the feedforward network.
linear2 (nn.Linear): Second linear layer of the feedforward network.
norm1 (nn.LayerNorm): Layer normalization for self-attention output.
norm2 (nn.LayerNorm): Layer normalization for cross-attention output.
norm3 (nn.LayerNorm): Layer normalization for feedforward network output.
dropout1 (nn.Dropout): Dropout layer after self-attention.
dropout2 (nn.Dropout): Dropout layer after cross-attention.
dropout3 (nn.Dropout): Dropout layer after feedforward network.
activation (nn.ReLU): Activation function for the feedforward network.
pos_enc_at_attn (bool): Flag to add positional encoding at attention.
pos_enc_at_cross_attn_queries (bool): Flag to add positional encoding to cross-attention queries.
pos_enc_at_cross_attn_keys (bool): Flag to add positional encoding to cross-attention keys.
Methods:
forward: Performs the full memory attention operation on input tensors.
_forward_sa: Performs self-attention on input tensor.
_forward_ca: Performs cross-attention between target and memory tensors.
Examples:
>>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)
>>> tgt = torch.randn(1, 100, 256)
>>> memory = torch.randn(1, 100, 64)
>>> pos = torch.randn(1, 100, 256)
>>> query_pos = torch.randn(1, 100, 256)
>>> output = layer(tgt, memory, pos, query_pos)
>>> print(output.shape)
torch.Size([1, 100, 256])
"""
def __init__(
self,
@ -21,7 +60,7 @@ class MemoryAttentionLayer(nn.Module):
pos_enc_at_cross_attn_keys: bool = True,
pos_enc_at_cross_attn_queries: bool = False,
):
"""Initializes a MemoryAttentionLayer with self-attention, cross-attention, and feedforward components."""
"""Initializes a memory attention layer with self-attention, cross-attention, and feedforward components."""
super().__init__()
self.d_model = d_model
self.dim_feedforward = dim_feedforward
@ -88,7 +127,7 @@ class MemoryAttentionLayer(nn.Module):
query_pos: Optional[Tensor] = None,
num_k_exclude_rope: int = 0,
) -> torch.Tensor:
"""Performs self-attention, cross-attention, and MLP operations on input tensors for memory-based attention."""
"""Processes input tensors using self-attention, cross-attention, and MLP for memory-based attention."""
tgt = self._forward_sa(tgt, query_pos)
tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
# MLP
@ -99,7 +138,35 @@ class MemoryAttentionLayer(nn.Module):
class MemoryAttention(nn.Module):
"""Memory attention module for processing sequential data with self and cross-attention mechanisms."""
"""
Memory attention module for processing sequential data with self and cross-attention mechanisms.
This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
for processing sequential data, particularly useful in transformer-like architectures.
Attributes:
d_model (int): The dimension of the model's hidden state.
layers (nn.ModuleList): A list of MemoryAttentionLayer modules.
num_layers (int): The number of attention layers.
norm (nn.LayerNorm): Layer normalization applied to the output.
pos_enc_at_input (bool): Whether to apply positional encoding at the input.
batch_first (bool): Whether the input tensors are in batch-first format.
Methods:
forward: Processes input tensors through the attention layers.
Examples:
>>> d_model = 256
>>> layer = MemoryAttentionLayer(d_model)
>>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
>>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)
>>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)
>>> curr_pos = torch.randn(10, 32, d_model)
>>> memory_pos = torch.randn(20, 32, d_model)
>>> output = attention(curr, memory, curr_pos, memory_pos)
>>> print(output.shape)
torch.Size([10, 32, 256])
"""
def __init__(
self,
@ -126,7 +193,7 @@ class MemoryAttention(nn.Module):
memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
):
"""Applies self-attention and cross-attention to input tensors, processing through multiple layers."""
"""Processes input tensors through multiple attention layers, applying self and cross-attention mechanisms."""
if isinstance(curr, list):
assert isinstance(curr_pos, list)
assert len(curr) == len(curr_pos) == 1

@ -9,25 +9,48 @@
from typing import List
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import trunc_normal_
from .decoders import MaskDecoder
from ultralytics.nn.modules import MLP
from .blocks import SAM2TwoWayTransformer
from .decoders import MaskDecoder, SAM2MaskDecoder
from .encoders import ImageEncoderViT, PromptEncoder
from .utils import get_1d_sine_pe, select_closest_cond_frames
# a large negative value as a placeholder score for missing objects
NO_OBJ_SCORE = -1024.0
class SAMModel(nn.Module):
"""
SAMModel (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate
image embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by
the mask decoder to predict object masks.
Segment Anything Model (SAM) for object segmentation tasks.
This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images
and input prompts.
Attributes:
mask_threshold (float): Threshold value for mask prediction.
image_encoder (ImageEncoderViT): The backbone used to encode the image into embeddings.
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
mask_decoder (MaskDecoder): Predicts object masks from the image and prompt embeddings.
pixel_mean (List[float]): Mean pixel values for image normalization.
pixel_std (List[float]): Standard deviation values for image normalization.
image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.
prompt_encoder (PromptEncoder): Encoder for various types of input prompts.
mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.
pixel_mean (torch.Tensor): Mean pixel values for image normalization, shape (3, 1, 1).
pixel_std (torch.Tensor): Standard deviation values for image normalization, shape (3, 1, 1).
Methods:
__init__: Initializes the SAMModel with encoders, decoder, and normalization parameters.
Examples:
>>> image_encoder = ImageEncoderViT(...)
>>> prompt_encoder = PromptEncoder(...)
>>> mask_decoder = MaskDecoder(...)
>>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
>>> # Further usage depends on SAMPredictor class
Notes:
All forward() operations are implemented in the SAMPredictor class.
"""
mask_threshold: float = 0.0
@ -43,17 +66,22 @@ class SAMModel(nn.Module):
"""
Initialize the SAMModel class to predict object masks from an image and input prompts.
Note:
All forward() operations moved to SAMPredictor.
Args:
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
pixel_mean (List[float], optional): Mean values for normalizing pixels in the input image. Defaults to
(123.675, 116.28, 103.53).
pixel_std (List[float], optional): Std values for normalizing pixels in the input image. Defaults to
(58.395, 57.12, 57.375).
pixel_mean (List[float]): Mean values for normalizing pixels in the input image.
pixel_std (List[float]): Std values for normalizing pixels in the input image.
Examples:
>>> image_encoder = ImageEncoderViT(...)
>>> prompt_encoder = PromptEncoder(...)
>>> mask_decoder = MaskDecoder(...)
>>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
>>> # Further usage depends on SAMPredictor class
Notes:
All forward() operations moved to SAMPredictor.
"""
super().__init__()
self.image_encoder = image_encoder
@ -61,3 +89,846 @@ class SAMModel(nn.Module):
self.mask_decoder = mask_decoder
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
class SAM2Model(torch.nn.Module):
"""
SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms
for temporal consistency and efficient tracking of objects across frames.
Attributes:
mask_threshold (float): Threshold value for mask prediction.
image_encoder (ImageEncoderViT): Visual encoder for extracting image features.
memory_attention (nn.Module): Module for attending to memory features.
memory_encoder (nn.Module): Encoder for generating memory representations.
num_maskmem (int): Number of accessible memory frames.
image_size (int): Size of input images.
backbone_stride (int): Stride of the backbone network output.
sam_prompt_embed_dim (int): Dimension of SAM prompt embeddings.
sam_image_embedding_size (int): Size of SAM image embeddings.
sam_prompt_encoder (PromptEncoder): Encoder for processing input prompts.
sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks.
obj_ptr_proj (nn.Module): Projection layer for object pointers.
obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers.
Methods:
forward_image: Processes image batch through encoder to extract multi-level features.
track_step: Performs a single tracking step, updating object masks and memory features.
Examples:
>>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
>>> image_batch = torch.rand(1, 3, 512, 512)
>>> features = model.forward_image(image_batch)
>>> track_results = model.track_step(0, True, features, None, None, None, {})
"""
mask_threshold: float = 0.0
def __init__(
self,
image_encoder,
memory_attention,
memory_encoder,
num_maskmem=7,
image_size=512,
backbone_stride=16,
sigmoid_scale_for_mem_enc=1.0,
sigmoid_bias_for_mem_enc=0.0,
binarize_mask_from_pts_for_mem_enc=False,
use_mask_input_as_output_without_sam=False,
max_cond_frames_in_attn=-1,
directly_add_no_mem_embed=False,
use_high_res_features_in_sam=False,
multimask_output_in_sam=False,
multimask_min_pt_num=1,
multimask_max_pt_num=1,
multimask_output_for_tracking=False,
use_multimask_token_for_obj_ptr: bool = False,
iou_prediction_use_sigmoid=False,
memory_temporal_stride_for_eval=1,
add_all_frames_to_correct_as_cond=False,
non_overlap_masks_for_mem_enc=False,
use_obj_ptrs_in_encoder=False,
max_obj_ptrs_in_encoder=16,
add_tpos_enc_to_obj_ptrs=True,
proj_tpos_enc_in_obj_ptrs=False,
only_obj_ptrs_in_the_past_for_eval=False,
pred_obj_scores: bool = False,
pred_obj_scores_mlp: bool = False,
fixed_no_obj_ptr: bool = False,
soft_no_obj_ptr: bool = False,
use_mlp_for_obj_ptr_proj: bool = False,
sam_mask_decoder_extra_args=None,
compile_image_encoder: bool = False,
):
"""
Initializes the SAM2Model for video object segmentation with memory-based tracking.
Args:
image_encoder (nn.Module): Visual encoder for extracting image features.
memory_attention (nn.Module): Module for attending to memory features.
memory_encoder (nn.Module): Encoder for generating memory representations.
num_maskmem (int): Number of accessible memory frames. Default is 7 (1 input frame + 6 previous frames).
image_size (int): Size of input images.
backbone_stride (int): Stride of the image backbone output.
sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
with clicks during evaluation.
use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
prompt encoder and mask decoder on frames with mask input.
max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
-1 means no limit.
directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
first frame.
use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
multimask_output_in_sam (bool): Whether to output multiple (3) masks for the first click on initial
conditioning frames.
multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
add_all_frames_to_correct_as_cond (bool): Whether to append frames with correction clicks to conditioning
frame list.
non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
memory encoder during evaluation.
use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder
cross-attention.
add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in
the encoder.
proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
encoding in object pointers.
only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
during evaluation.
pred_obj_scores (bool): Whether to predict if there is an object in the frame.
pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
sam_mask_decoder_extra_args (Dict | None): Extra arguments for constructing the SAM mask decoder.
compile_image_encoder (bool): Whether to compile the image encoder for faster inference.
Examples:
>>> image_encoder = ImageEncoderViT(...)
>>> memory_attention = SAM2TwoWayTransformer(...)
>>> memory_encoder = nn.Sequential(...)
>>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
>>> image_batch = torch.rand(1, 3, 512, 512)
>>> features = model.forward_image(image_batch)
>>> track_results = model.track_step(0, True, features, None, None, None, {})
"""
super().__init__()
# Part 1: the image backbone
self.image_encoder = image_encoder
# Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
self.use_high_res_features_in_sam = use_high_res_features_in_sam
self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
if use_obj_ptrs_in_encoder:
# A conv layer to downsample the mask prompt to stride 4 (the same stride as
# low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
# so that it can be fed into the SAM mask decoder to generate a pointer.
self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
if proj_tpos_enc_in_obj_ptrs:
assert add_tpos_enc_to_obj_ptrs # these options need to be used together
self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
# Part 2: memory attention to condition current frame's visual features
# with memories (and obj ptrs) from past frames
self.memory_attention = memory_attention
self.hidden_dim = memory_attention.d_model
# Part 3: memory encoder for the previous frame's outputs
self.memory_encoder = memory_encoder
self.mem_dim = self.hidden_dim
if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
# if there is compression of memories along channel dim
self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
self.num_maskmem = num_maskmem # Number of memories accessible
# Temporal encoding of the memories
self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
trunc_normal_(self.maskmem_tpos_enc, std=0.02)
# a single token to indicate no memory embedding from previous frames
self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
trunc_normal_(self.no_mem_embed, std=0.02)
trunc_normal_(self.no_mem_pos_enc, std=0.02)
self.directly_add_no_mem_embed = directly_add_no_mem_embed
# Apply sigmoid to the output raw mask logits (to turn them from
# range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
# On frames with mask input, whether to directly output the input mask without
# using a SAM prompt encoder + mask decoder
self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
self.multimask_output_in_sam = multimask_output_in_sam
self.multimask_min_pt_num = multimask_min_pt_num
self.multimask_max_pt_num = multimask_max_pt_num
self.multimask_output_for_tracking = multimask_output_for_tracking
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
# Part 4: SAM-style prompt encoder (for both mask and point inputs)
# and SAM-style mask decoder for the final mask output
self.image_size = image_size
self.backbone_stride = backbone_stride
self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
self.pred_obj_scores = pred_obj_scores
self.pred_obj_scores_mlp = pred_obj_scores_mlp
self.fixed_no_obj_ptr = fixed_no_obj_ptr
self.soft_no_obj_ptr = soft_no_obj_ptr
if self.fixed_no_obj_ptr:
assert self.pred_obj_scores
assert self.use_obj_ptrs_in_encoder
if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
trunc_normal_(self.no_obj_ptr, std=0.02)
self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
self._build_sam_heads()
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
self.max_cond_frames_in_attn = max_cond_frames_in_attn
# Model compilation
if compile_image_encoder:
# Compile the forward function (not the full module) to allow loading checkpoints.
print("Image encoder compilation is enabled. First forward pass will be slow.")
self.image_encoder.forward = torch.compile(
self.image_encoder.forward,
mode="max-autotune",
fullgraph=True,
dynamic=False,
)
@property
def device(self):
"""Returns the device on which the model's parameters are stored."""
return next(self.parameters()).device
def forward(self, *args, **kwargs):
"""Processes image and prompt inputs to generate object masks and scores in video sequences."""
raise NotImplementedError(
"Please use the corresponding methods in SAM2VideoPredictor for inference."
"See notebooks/video_predictor_example.ipynb for an example."
)
def _build_sam_heads(self):
"""Builds SAM-style prompt encoder and mask decoder for image segmentation tasks."""
self.sam_prompt_embed_dim = self.hidden_dim
self.sam_image_embedding_size = self.image_size // self.backbone_stride
# build PromptEncoder and MaskDecoder from SAM
# (their hyperparameters like `mask_in_chans=16` are from SAM code)
self.sam_prompt_encoder = PromptEncoder(
embed_dim=self.sam_prompt_embed_dim,
image_embedding_size=(
self.sam_image_embedding_size,
self.sam_image_embedding_size,
),
input_image_size=(self.image_size, self.image_size),
mask_in_chans=16,
)
self.sam_mask_decoder = SAM2MaskDecoder(
num_multimask_outputs=3,
transformer=SAM2TwoWayTransformer(
depth=2,
embedding_dim=self.sam_prompt_embed_dim,
mlp_dim=2048,
num_heads=8,
),
transformer_dim=self.sam_prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
use_high_res_features=self.use_high_res_features_in_sam,
iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
pred_obj_scores=self.pred_obj_scores,
pred_obj_scores_mlp=self.pred_obj_scores_mlp,
use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
**(self.sam_mask_decoder_extra_args or {}),
)
if self.use_obj_ptrs_in_encoder:
# a linear projection on SAM output tokens to turn them into object pointers
self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
if self.use_mlp_for_obj_ptr_proj:
self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
else:
self.obj_ptr_proj = torch.nn.Identity()
if self.proj_tpos_enc_in_obj_ptrs:
# a linear projection on temporal positional encoding in object pointers to
# avoid potential interference with spatial positional encoding
self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
else:
self.obj_ptr_tpos_proj = torch.nn.Identity()
def _forward_sam_heads(
self,
backbone_features,
point_inputs=None,
mask_inputs=None,
high_res_features=None,
multimask_output=False,
):
"""
Forward pass through SAM prompt encoders and mask heads.
This method processes image features and optional point/mask inputs to generate object masks and scores.
Args:
backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts.
'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
pixel-unit coordinates in (x, y) format for P input points.
'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
0 means negative clicks, and -1 means padding.
mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
same spatial size as the image.
high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes
(B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps
for SAM decoder.
multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
output only 1 mask and its IoU estimate.
Returns:
(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
ious: Tensor of shape (B, M) with estimated IoU for each output mask.
low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask.
high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask.
obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
object_score_logits: Tensor of shape (B,) with object score logits.
Where M is 3 if multimask_output=True, and 1 if multimask_output=False.
Examples:
>>> backbone_features = torch.rand(1, 256, 32, 32)
>>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
>>> mask_inputs = torch.rand(1, 1, 512, 512)
>>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
>>> low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits = results
"""
B = backbone_features.size(0)
device = backbone_features.device
assert backbone_features.size(1) == self.sam_prompt_embed_dim
assert backbone_features.size(2) == self.sam_image_embedding_size
assert backbone_features.size(3) == self.sam_image_embedding_size
# a) Handle point prompts
if point_inputs is not None:
sam_point_coords = point_inputs["point_coords"]
sam_point_labels = point_inputs["point_labels"]
assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
else:
# If no points are provide, pad with an empty point (with label -1)
sam_point_coords = torch.zeros(B, 1, 2, device=device)
sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
# b) Handle mask prompts
if mask_inputs is not None:
# If mask_inputs is provided, downsize it into low-res mask input if needed
# and feed it as a dense mask prompt into the SAM mask encoder
assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
sam_mask_prompt = F.interpolate(
mask_inputs.float(),
size=self.sam_prompt_encoder.mask_input_size,
align_corners=False,
mode="bilinear",
antialias=True, # use antialias for downsampling
)
else:
sam_mask_prompt = mask_inputs
else:
# Otherwise, simply feed None (and SAM's prompt encoder will add
# a learned `no_mask_embed` to indicate no mask input in this case).
sam_mask_prompt = None
sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
points=(sam_point_coords, sam_point_labels),
boxes=None,
masks=sam_mask_prompt,
)
(
low_res_multimasks,
ious,
sam_output_tokens,
object_score_logits,
) = self.sam_mask_decoder(
image_embeddings=backbone_features,
image_pe=self.sam_prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
repeat_image=False, # the image is already batched
high_res_features=high_res_features,
)
if self.pred_obj_scores:
is_obj_appearing = object_score_logits > 0
# Mask used for spatial memories is always a *hard* choice between obj and no obj,
# consistent with the actual mask prediction
low_res_multimasks = torch.where(
is_obj_appearing[:, None, None],
low_res_multimasks,
NO_OBJ_SCORE,
)
# convert masks from possibly bfloat16 (or float16) to float32
# (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
low_res_multimasks = low_res_multimasks.float()
high_res_multimasks = F.interpolate(
low_res_multimasks,
size=(self.image_size, self.image_size),
mode="bilinear",
align_corners=False,
)
sam_output_token = sam_output_tokens[:, 0]
if multimask_output:
# take the best mask prediction (with the highest IoU estimation)
best_iou_inds = torch.argmax(ious, dim=-1)
batch_inds = torch.arange(B, device=device)
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
if sam_output_tokens.size(1) > 1:
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
else:
low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
# Extract object pointer from the SAM output token (with occlusion handling)
obj_ptr = self.obj_ptr_proj(sam_output_token)
if self.pred_obj_scores:
# Allow *soft* no obj ptr, unlike for masks
if self.soft_no_obj_ptr:
# Only hard possible with gt
assert not self.teacher_force_obj_scores_for_mem
lambda_is_obj_appearing = object_score_logits.sigmoid()
else:
lambda_is_obj_appearing = is_obj_appearing.float()
if self.fixed_no_obj_ptr:
obj_ptr = lambda_is_obj_appearing * obj_ptr
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
return (
low_res_multimasks,
high_res_multimasks,
ious,
low_res_masks,
high_res_masks,
obj_ptr,
object_score_logits,
)
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
"""Processes mask inputs directly as output, bypassing SAM encoder/decoder."""
# Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
mask_inputs_float = mask_inputs.float()
high_res_masks = mask_inputs_float * out_scale + out_bias
low_res_masks = F.interpolate(
high_res_masks,
size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
align_corners=False,
mode="bilinear",
antialias=True, # use antialias for downsampling
)
# a dummy IoU prediction of all 1's under mask input
ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
if not self.use_obj_ptrs_in_encoder:
# all zeros as a dummy object pointer (of shape [B, C])
obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
else:
# produce an object pointer using the SAM decoder from the mask input
_, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
backbone_features=backbone_features,
mask_inputs=self.mask_downsample(mask_inputs_float),
high_res_features=high_res_features,
)
# In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
# Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
# on the object_scores from the SAM decoder.
is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
is_obj_appearing = is_obj_appearing[..., None]
lambda_is_obj_appearing = is_obj_appearing.float()
object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
if self.pred_obj_scores:
if self.fixed_no_obj_ptr:
obj_ptr = lambda_is_obj_appearing * obj_ptr
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
return (
low_res_masks,
high_res_masks,
ious,
low_res_masks,
high_res_masks,
obj_ptr,
object_score_logits,
)
def forward_image(self, img_batch: torch.Tensor):
"""Processes image batch through encoder to extract multi-level features for SAM model."""
backbone_out = self.image_encoder(img_batch)
if self.use_high_res_features_in_sam:
# precompute projected level 0 and level 1 features in SAM decoder
# to avoid running it again on every SAM click
backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
return backbone_out
def _prepare_backbone_features(self, backbone_out):
"""Prepares and flattens visual features from the image backbone output for further processing."""
backbone_out = backbone_out.copy()
assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
# flatten NxCxHxW to HWxNxC
vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
def _prepare_memory_conditioned_features(
self,
frame_idx,
is_init_cond_frame,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
output_dict,
num_frames,
track_in_reverse=False, # tracking in reverse time order (for demo usage)
):
"""Prepares memory-conditioned features by fusing current frame's visual features with previous memories."""
B = current_vision_feats[-1].size(1) # batch size on this frame
C = self.hidden_dim
H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
device = current_vision_feats[-1].device
# The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
# In this case, we skip the fusion with any memory.
if self.num_maskmem == 0: # Disable memory and skip fusion
pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
return pix_feat
num_obj_ptr_tokens = 0
# Step 1: condition the visual features of the current frame on previous memories
if not is_init_cond_frame:
# Retrieve the memories encoded with the maskmem backbone
to_cat_memory, to_cat_memory_pos_embed = [], []
# Add conditioning frames's output first (all cond frames have t_pos=0 for
# when getting temporal positional embedding below)
assert len(output_dict["cond_frame_outputs"]) > 0
# Select a maximum number of temporally closest cond frames for cross attention
cond_outputs = output_dict["cond_frame_outputs"]
selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
frame_idx, cond_outputs, self.max_cond_frames_in_attn
)
t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
# Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
# the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
# We also allow taking the memory frame non-consecutively (with r>1), in which case
# we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
r = self.memory_temporal_stride_for_eval
for t_pos in range(1, self.num_maskmem):
t_rel = self.num_maskmem - t_pos # how many frames before current frame
if t_rel == 1:
# for t_rel == 1, we take the last frame (regardless of r)
if not track_in_reverse:
# the frame immediately before this frame (i.e. frame_idx - 1)
prev_frame_idx = frame_idx - t_rel
else:
# the frame immediately after this frame (i.e. frame_idx + 1)
prev_frame_idx = frame_idx + t_rel
else:
# for t_rel >= 2, we take the memory frame from every r-th frames
if not track_in_reverse:
# first find the nearest frame among every r-th frames before this frame
# for r=1, this would be (frame_idx - 2)
prev_frame_idx = ((frame_idx - 2) // r) * r
# then seek further among every r-th frames
prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
else:
# first find the nearest frame among every r-th frames after this frame
# for r=1, this would be (frame_idx + 2)
prev_frame_idx = -(-(frame_idx + 2) // r) * r
# then seek further among every r-th frames
prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
if out is None:
# If an unselected conditioning frame is among the last (self.num_maskmem - 1)
# frames, we still attend to it as if it's a non-conditioning frame.
out = unselected_cond_outputs.get(prev_frame_idx, None)
t_pos_and_prevs.append((t_pos, out))
for t_pos, prev in t_pos_and_prevs:
if prev is None:
continue # skip padding frames
# "maskmem_features" might have been offloaded to CPU in demo use cases,
# so we load it back to GPU (it's a no-op if it's already on GPU).
feats = prev["maskmem_features"].cuda(non_blocking=True)
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
# Spatial positional encoding (it might have been offloaded to CPU in eval)
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
# Temporal positional encoding
maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
to_cat_memory_pos_embed.append(maskmem_enc)
# Construct the list of past object pointers
if self.use_obj_ptrs_in_encoder:
max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
# First add those object pointers from selected conditioning frames
# (optionally, only include object pointers in the past during evaluation)
if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
ptr_cond_outputs = {
t: out
for t, out in selected_cond_outputs.items()
if (t >= frame_idx if track_in_reverse else t <= frame_idx)
}
else:
ptr_cond_outputs = selected_cond_outputs
pos_and_ptrs = [
# Temporal pos encoding contains how far away each pointer is from current frame
(abs(frame_idx - t), out["obj_ptr"])
for t, out in ptr_cond_outputs.items()
]
# Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
for t_diff in range(1, max_obj_ptrs_in_encoder):
t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
if t < 0 or (num_frames is not None and t >= num_frames):
break
out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
if out is not None:
pos_and_ptrs.append((t_diff, out["obj_ptr"]))
# If we have at least one object pointer, add them to the across attention
if len(pos_and_ptrs) > 0:
pos_list, ptrs_list = zip(*pos_and_ptrs)
# stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
obj_ptrs = torch.stack(ptrs_list, dim=0)
# a temporal positional embedding based on how far each object pointer is from
# the current frame (sine embedding normalized by the max pointer num).
if self.add_tpos_enc_to_obj_ptrs:
t_diff_max = max_obj_ptrs_in_encoder - 1
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
obj_pos = torch.tensor(pos_list, device=device)
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
else:
obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
if self.mem_dim < C:
# split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
to_cat_memory.append(obj_ptrs)
to_cat_memory_pos_embed.append(obj_pos)
num_obj_ptr_tokens = obj_ptrs.shape[0]
else:
num_obj_ptr_tokens = 0
else:
# for initial conditioning frames, encode them without using any previous memory
if self.directly_add_no_mem_embed:
# directly add no-mem embedding (instead of using the transformer encoder)
pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
return pix_feat_with_mem
# Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
# Step 2: Concatenate the memories and forward through the transformer encoder
memory = torch.cat(to_cat_memory, dim=0)
memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
pix_feat_with_mem = self.memory_attention(
curr=current_vision_feats,
curr_pos=current_vision_pos_embeds,
memory=memory,
memory_pos=memory_pos_embed,
num_obj_ptr_tokens=num_obj_ptr_tokens,
)
# reshape the output (HW)BC => BCHW
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
return pix_feat_with_mem
def _encode_new_memory(
self,
current_vision_feats,
feat_sizes,
pred_masks_high_res,
is_mask_from_pts,
):
"""Encodes frame features and masks into a new memory representation for video segmentation."""
B = current_vision_feats[-1].size(1) # batch size on this frame
C = self.hidden_dim
H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
# top-level feature, (HW)BC => BCHW
pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
if self.non_overlap_masks_for_mem_enc and not self.training:
# optionally, apply non-overlapping constraints to the masks (it's applied
# in the batch dimension and should only be used during eval, where all
# the objects come from the same video under batch size 1).
pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
# scale the raw mask logits with a temperature before applying sigmoid
binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
if binarize and not self.training:
mask_for_mem = (pred_masks_high_res > 0).float()
else:
# apply sigmoid on the raw mask logits to turn them into range (0, 1)
mask_for_mem = torch.sigmoid(pred_masks_high_res)
# apply scale and bias terms to the sigmoid probabilities
if self.sigmoid_scale_for_mem_enc != 1.0:
mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
if self.sigmoid_bias_for_mem_enc != 0.0:
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
maskmem_out = self.memory_encoder(
pix_feat,
mask_for_mem,
skip_mask_sigmoid=True, # sigmoid already applied
)
maskmem_features = maskmem_out["vision_features"]
maskmem_pos_enc = maskmem_out["vision_pos_enc"]
return maskmem_features, maskmem_pos_enc
def track_step(
self,
frame_idx,
is_init_cond_frame,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
point_inputs,
mask_inputs,
output_dict,
num_frames,
track_in_reverse=False, # tracking in reverse time order (for demo usage)
# Whether to run the memory encoder on the predicted masks. Sometimes we might want
# to skip the memory encoder with `run_mem_encoder=False`. For example,
# in demo we might call `track_step` multiple times for each user click,
# and only encode the memory when the user finalizes their clicks. And in ablation
# settings like SAM training on static images, we don't need the memory encoder.
run_mem_encoder=True,
# The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
prev_sam_mask_logits=None,
):
"""Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
if len(current_vision_feats) > 1:
high_res_features = [
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
]
else:
high_res_features = None
if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
# When use_mask_input_as_output_without_sam=True, we directly output the mask input
# (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
pix_feat = current_vision_feats[-1].permute(1, 2, 0)
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
else:
# fused the visual feature with previous memory features in the memory bank
pix_feat_with_mem = self._prepare_memory_conditioned_features(
frame_idx=frame_idx,
is_init_cond_frame=is_init_cond_frame,
current_vision_feats=current_vision_feats[-1:],
current_vision_pos_embeds=current_vision_pos_embeds[-1:],
feat_sizes=feat_sizes[-1:],
output_dict=output_dict,
num_frames=num_frames,
track_in_reverse=track_in_reverse,
)
# apply SAM-style segmentation head
# here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
# e.g. in demo where such logits come from earlier interaction instead of correction sampling
# (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
if prev_sam_mask_logits is not None:
assert point_inputs is not None and mask_inputs is None
mask_inputs = prev_sam_mask_logits
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
sam_outputs = self._forward_sam_heads(
backbone_features=pix_feat_with_mem,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
high_res_features=high_res_features,
multimask_output=multimask_output,
)
(
_,
_,
_,
low_res_masks,
high_res_masks,
obj_ptr,
_,
) = sam_outputs
current_out["pred_masks"] = low_res_masks
current_out["pred_masks_high_res"] = high_res_masks
current_out["obj_ptr"] = obj_ptr
# Finally run the memory encoder on the predicted mask to encode
# it into a new memory feature (that can be used in future frames)
if run_mem_encoder and self.num_maskmem > 0:
high_res_masks_for_mem_enc = high_res_masks
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
current_vision_feats=current_vision_feats,
feat_sizes=feat_sizes,
pred_masks_high_res=high_res_masks_for_mem_enc,
is_mask_from_pts=(point_inputs is not None),
)
current_out["maskmem_features"] = maskmem_features
current_out["maskmem_pos_enc"] = maskmem_pos_enc
else:
current_out["maskmem_features"] = None
current_out["maskmem_pos_enc"] = None
return current_out
def _use_multimask(self, is_init_cond_frame, point_inputs):
"""Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
multimask_output = (
self.multimask_output_in_sam
and (is_init_cond_frame or self.multimask_output_for_tracking)
and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
)
return multimask_output
def _apply_non_overlapping_constraints(self, pred_masks):
"""Applies non-overlapping constraints to masks, keeping highest scoring object per location."""
batch_size = pred_masks.size(0)
if batch_size == 1:
return pred_masks
device = pred_masks.device
# "max_obj_inds": object index of the object with the highest score at each location
max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
# "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
keep = max_obj_inds == batch_obj_inds
# suppress overlapping regions' scores below -10.0 so that the foreground regions
# don't overlap (here sigmoid(-10.0)=4.5398e-05)
pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
return pred_masks

@ -17,16 +17,40 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from ultralytics.nn.modules import LayerNorm2d
from ultralytics.utils.instance import to_2tuple
class Conv2d_BN(torch.nn.Sequential):
"""A sequential container that performs 2D convolution followed by batch normalization."""
"""
A sequential container that performs 2D convolution followed by batch normalization.
Attributes:
c (torch.nn.Conv2d): 2D convolution layer.
1 (torch.nn.BatchNorm2d): Batch normalization layer.
Methods:
__init__: Initializes the Conv2d_BN with specified parameters.
Args:
a (int): Number of input channels.
b (int): Number of output channels.
ks (int): Kernel size for the convolution. Defaults to 1.
stride (int): Stride for the convolution. Defaults to 1.
pad (int): Padding for the convolution. Defaults to 0.
dilation (int): Dilation factor for the convolution. Defaults to 1.
groups (int): Number of groups for the convolution. Defaults to 1.
bn_weight_init (float): Initial value for batch normalization weight. Defaults to 1.
Examples:
>>> conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1)
>>> input_tensor = torch.randn(1, 3, 224, 224)
>>> output = conv_bn(input_tensor)
>>> print(output.shape)
"""
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
"""Initializes the MBConv model with given input channels, output channels, expansion ratio, activation, and
drop path.
"""
"""Initializes a sequential container with 2D convolution followed by batch normalization."""
super().__init__()
self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = torch.nn.BatchNorm2d(b)
@ -36,12 +60,29 @@ class Conv2d_BN(torch.nn.Sequential):
class PatchEmbed(nn.Module):
"""Embeds images into patches and projects them into a specified embedding dimension."""
"""
Embeds images into patches and projects them into a specified embedding dimension.
Attributes:
patches_resolution (Tuple[int, int]): Resolution of the patches after embedding.
num_patches (int): Total number of patches.
in_chans (int): Number of input channels.
embed_dim (int): Dimension of the embedding.
seq (nn.Sequential): Sequence of convolutional and activation layers for patch embedding.
Methods:
forward: Processes the input tensor through the patch embedding sequence.
Examples:
>>> import torch
>>> patch_embed = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU)
>>> x = torch.randn(1, 3, 224, 224)
>>> output = patch_embed(x)
>>> print(output.shape)
"""
def __init__(self, in_chans, embed_dim, resolution, activation):
"""Initialize the PatchMerging class with specified input, output dimensions, resolution and activation
function.
"""
"""Initializes patch embedding with convolutional layers for image-to-patch conversion and projection."""
super().__init__()
img_size: Tuple[int, int] = to_2tuple(resolution)
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
@ -56,17 +97,40 @@ class PatchEmbed(nn.Module):
)
def forward(self, x):
"""Runs input tensor 'x' through the PatchMerging model's sequence of operations."""
"""Processes input tensor through patch embedding sequence, converting images to patch embeddings."""
return self.seq(x)
class MBConv(nn.Module):
"""Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture."""
"""
Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.
Attributes:
in_chans (int): Number of input channels.
hidden_chans (int): Number of hidden channels.
out_chans (int): Number of output channels.
conv1 (Conv2d_BN): First convolutional layer.
act1 (nn.Module): First activation function.
conv2 (Conv2d_BN): Depthwise convolutional layer.
act2 (nn.Module): Second activation function.
conv3 (Conv2d_BN): Final convolutional layer.
act3 (nn.Module): Third activation function.
drop_path (nn.Module): Drop path layer (Identity for inference).
Methods:
forward: Performs the forward pass through the MBConv layer.
Examples:
>>> in_chans, out_chans = 32, 64
>>> mbconv = MBConv(in_chans, out_chans, expand_ratio=4, activation=nn.ReLU, drop_path=0.1)
>>> x = torch.randn(1, in_chans, 56, 56)
>>> output = mbconv(x)
>>> print(output.shape)
torch.Size([1, 64, 56, 56])
"""
def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
"""Initializes a convolutional layer with specified dimensions, input resolution, depth, and activation
function.
"""
"""Initializes the MBConv layer with specified input/output channels, expansion ratio, and activation."""
super().__init__()
self.in_chans = in_chans
self.hidden_chans = int(in_chans * expand_ratio)
@ -86,7 +150,7 @@ class MBConv(nn.Module):
self.drop_path = nn.Identity()
def forward(self, x):
"""Implements the forward pass for the model architecture."""
"""Implements the forward pass of MBConv, applying convolutions and skip connection."""
shortcut = x
x = self.conv1(x)
x = self.act1(x)
@ -99,12 +163,34 @@ class MBConv(nn.Module):
class PatchMerging(nn.Module):
"""Merges neighboring patches in the feature map and projects to a new dimension."""
"""
Merges neighboring patches in the feature map and projects to a new dimension.
This class implements a patch merging operation that combines spatial information and adjusts the feature
dimension. It uses a series of convolutional layers with batch normalization to achieve this.
Attributes:
input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map.
dim (int): The input dimension of the feature map.
out_dim (int): The output dimension after merging and projection.
act (nn.Module): The activation function used between convolutions.
conv1 (Conv2d_BN): The first convolutional layer for dimension projection.
conv2 (Conv2d_BN): The second convolutional layer for spatial merging.
conv3 (Conv2d_BN): The third convolutional layer for final projection.
Methods:
forward: Applies the patch merging operation to the input tensor.
Examples:
>>> input_resolution = (56, 56)
>>> patch_merging = PatchMerging(input_resolution, dim=64, out_dim=128, activation=nn.ReLU)
>>> x = torch.randn(4, 64, 56, 56)
>>> output = patch_merging(x)
>>> print(output.shape)
"""
def __init__(self, input_resolution, dim, out_dim, activation):
"""Initializes the ConvLayer with specific dimension, input resolution, depth, activation, drop path, and other
optional parameters.
"""
"""Initializes the PatchMerging module for merging and projecting neighboring patches in feature maps."""
super().__init__()
self.input_resolution = input_resolution
@ -117,7 +203,7 @@ class PatchMerging(nn.Module):
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
def forward(self, x):
"""Applies forward pass on the input utilizing convolution and activation layers, and returns the result."""
"""Applies patch merging and dimension projection to the input feature map."""
if x.ndim == 3:
H, W = self.input_resolution
B = len(x)
@ -137,7 +223,24 @@ class ConvLayer(nn.Module):
"""
Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
Optionally applies downsample operations to the output, and provides support for gradient checkpointing.
This layer optionally applies downsample operations to the output and supports gradient checkpointing.
Attributes:
dim (int): Dimensionality of the input and output.
input_resolution (Tuple[int, int]): Resolution of the input image.
depth (int): Number of MBConv layers in the block.
use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
blocks (nn.ModuleList): List of MBConv layers.
downsample (Optional[Callable]): Function for downsampling the output.
Methods:
forward: Processes the input through the convolutional layers.
Examples:
>>> input_tensor = torch.randn(1, 64, 56, 56)
>>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
>>> output = conv_layer(input_tensor)
>>> print(output.shape)
"""
def __init__(
@ -155,16 +258,25 @@ class ConvLayer(nn.Module):
"""
Initializes the ConvLayer with the given dimensions and settings.
This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and
optionally applies downsampling to the output.
Args:
dim (int): The dimensionality of the input and output.
input_resolution (Tuple[int, int]): The resolution of the input image.
depth (int): The number of MBConv layers in the block.
activation (Callable): Activation function applied after each convolution.
drop_path (Union[float, List[float]]): Drop path rate. Single float or a list of floats for each MBConv.
drop_path (float | List[float]): Drop path rate. Single float or a list of floats for each MBConv.
downsample (Optional[Callable]): Function for downsampling the output. None to skip downsampling.
use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`.
conv_expand_ratio (float): Expansion ratio for the MBConv layers.
Examples:
>>> input_tensor = torch.randn(1, 64, 56, 56)
>>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
>>> output = conv_layer(input_tensor)
>>> print(output.shape)
"""
super().__init__()
self.dim = dim
@ -194,7 +306,7 @@ class ConvLayer(nn.Module):
)
def forward(self, x):
"""Processes the input through a series of convolutional layers and returns the activated output."""
"""Processes input through convolutional layers, applying MBConv blocks and optional downsampling."""
for blk in self.blocks:
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
return x if self.downsample is None else self.downsample(x)
@ -202,13 +314,33 @@ class ConvLayer(nn.Module):
class Mlp(nn.Module):
"""
Multi-layer Perceptron (MLP) for transformer architectures.
Multi-layer Perceptron (MLP) module for transformer architectures.
This module applies layer normalization, two fully-connected layers with an activation function in between,
and dropout. It is commonly used in transformer-based architectures.
This layer takes an input with in_features, applies layer normalization and two fully-connected layers.
Attributes:
norm (nn.LayerNorm): Layer normalization applied to the input.
fc1 (nn.Linear): First fully-connected layer.
fc2 (nn.Linear): Second fully-connected layer.
act (nn.Module): Activation function applied after the first fully-connected layer.
drop (nn.Dropout): Dropout layer applied after the activation function.
Methods:
forward: Applies the MLP operations on the input tensor.
Examples:
>>> import torch
>>> from torch import nn
>>> mlp = Mlp(in_features=256, hidden_features=512, out_features=256, act_layer=nn.GELU, drop=0.1)
>>> x = torch.randn(32, 100, 256)
>>> output = mlp(x)
>>> print(output.shape)
torch.Size([32, 100, 256])
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
"""Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc."""
"""Initializes a multi-layer perceptron with configurable input, hidden, and output dimensions."""
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
@ -219,7 +351,7 @@ class Mlp(nn.Module):
self.drop = nn.Dropout(drop)
def forward(self, x):
"""Applies operations on input x and returns modified x, runs downsample if not None."""
"""Applies MLP operations: layer norm, FC layers, activation, and dropout to the input tensor."""
x = self.norm(x)
x = self.fc1(x)
x = self.act(x)
@ -230,12 +362,37 @@ class Mlp(nn.Module):
class Attention(torch.nn.Module):
"""
Multi-head attention module with support for spatial awareness, applying attention biases based on spatial
resolution. Implements trainable attention biases for each unique offset between spatial positions in the resolution
grid.
Multi-head attention module with spatial awareness and trainable attention biases.
This module implements a multi-head attention mechanism with support for spatial awareness, applying
attention biases based on spatial resolution. It includes trainable attention biases for each unique
offset between spatial positions in the resolution grid.
Attributes:
ab (Tensor, optional): Cached attention biases for inference, deleted during training.
num_heads (int): Number of attention heads.
scale (float): Scaling factor for attention scores.
key_dim (int): Dimensionality of the keys and queries.
nh_kd (int): Product of num_heads and key_dim.
d (int): Dimensionality of the value vectors.
dh (int): Product of d and num_heads.
attn_ratio (float): Attention ratio affecting the dimensions of the value vectors.
norm (nn.LayerNorm): Layer normalization applied to input.
qkv (nn.Linear): Linear layer for computing query, key, and value projections.
proj (nn.Linear): Linear layer for final projection.
attention_biases (nn.Parameter): Learnable attention biases.
attention_bias_idxs (Tensor): Indices for attention biases.
ab (Tensor): Cached attention biases for inference, deleted during training.
Methods:
train: Sets the module in training mode and handles the 'ab' attribute.
forward: Performs the forward pass of the attention mechanism.
Examples:
>>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
>>> x = torch.randn(1, 196, 256)
>>> output = attn(x)
>>> print(output.shape)
torch.Size([1, 196, 256])
"""
def __init__(
@ -247,17 +404,28 @@ class Attention(torch.nn.Module):
resolution=(14, 14),
):
"""
Initializes the Attention module.
Initializes the Attention module for multi-head attention with spatial awareness.
This module implements a multi-head attention mechanism with support for spatial awareness, applying
attention biases based on spatial resolution. It includes trainable attention biases for each unique
offset between spatial positions in the resolution grid.
Args:
dim (int): The dimensionality of the input and output.
key_dim (int): The dimensionality of the keys and queries.
num_heads (int, optional): Number of attention heads. Default is 8.
attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors. Default is 4.
resolution (Tuple[int, int], optional): Spatial resolution of the input feature map. Default is (14, 14).
num_heads (int): Number of attention heads. Default is 8.
attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors. Default is 4.
resolution (Tuple[int, int]): Spatial resolution of the input feature map. Default is (14, 14).
Raises:
AssertionError: If `resolution` is not a tuple of length 2.
AssertionError: If 'resolution' is not a tuple of length 2.
Examples:
>>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
>>> x = torch.randn(1, 196, 256)
>>> output = attn(x)
>>> print(output.shape)
torch.Size([1, 196, 256])
"""
super().__init__()
@ -290,7 +458,7 @@ class Attention(torch.nn.Module):
@torch.no_grad()
def train(self, mode=True):
"""Sets the module in training mode and handles attribute 'ab' based on the mode."""
"""Performs multi-head attention with spatial awareness and trainable attention biases."""
super().train(mode)
if mode and hasattr(self, "ab"):
del self.ab
@ -298,7 +466,7 @@ class Attention(torch.nn.Module):
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x): # x
"""Performs forward pass over the input tensor 'x' by applying normalization and querying keys/values."""
"""Applies multi-head attention with spatial awareness and trainable attention biases."""
B, N, _ = x.shape # B, N, C
# Normalization
@ -322,7 +490,34 @@ class Attention(torch.nn.Module):
class TinyViTBlock(nn.Module):
"""TinyViT Block that applies self-attention and a local convolution to the input."""
"""
TinyViT Block that applies self-attention and a local convolution to the input.
This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
local convolutions to process input features efficiently.
Attributes:
dim (int): The dimensionality of the input and output.
input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
num_heads (int): Number of attention heads.
window_size (int): Size of the attention window.
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
drop_path (nn.Module): Stochastic depth layer, identity function during inference.
attn (Attention): Self-attention module.
mlp (Mlp): Multi-layer perceptron module.
local_conv (Conv2d_BN): Depth-wise local convolution layer.
Methods:
forward: Processes the input through the TinyViT block.
extra_repr: Returns a string with extra information about the block's parameters.
Examples:
>>> input_tensor = torch.randn(1, 196, 192)
>>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
>>> output = block(input_tensor)
>>> print(output.shape)
torch.Size([1, 196, 192])
"""
def __init__(
self,
@ -337,22 +532,32 @@ class TinyViTBlock(nn.Module):
activation=nn.GELU,
):
"""
Initializes the TinyViTBlock.
Initializes a TinyViT block with self-attention and local convolution.
This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
local convolutions to process input features efficiently.
Args:
dim (int): The dimensionality of the input and output.
input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
dim (int): Dimensionality of the input and output features.
input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).
num_heads (int): Number of attention heads.
window_size (int, optional): Window size for attention. Default is 7.
mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4.
drop (float, optional): Dropout rate. Default is 0.
drop_path (float, optional): Stochastic depth rate. Default is 0.
local_conv_size (int, optional): The kernel size of the local convolution. Default is 3.
activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU.
window_size (int): Size of the attention window. Must be greater than 0.
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
drop (float): Dropout rate.
drop_path (float): Stochastic depth rate.
local_conv_size (int): Kernel size of the local convolution.
activation (torch.nn.Module): Activation function for MLP.
Raises:
AssertionError: If `window_size` is not greater than 0.
AssertionError: If `dim` is not divisible by `num_heads`.
AssertionError: If window_size is not greater than 0.
AssertionError: If dim is not divisible by num_heads.
Examples:
>>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
>>> input_tensor = torch.randn(1, 196, 192)
>>> output = block(input_tensor)
>>> print(output.shape)
torch.Size([1, 196, 192])
"""
super().__init__()
self.dim = dim
@ -380,9 +585,7 @@ class TinyViTBlock(nn.Module):
self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
def forward(self, x):
"""Applies attention-based transformation or padding to input 'x' before passing it through a local
convolution.
"""
"""Applies self-attention, local convolution, and MLP operations to the input tensor."""
h, w = self.input_resolution
b, hw, c = x.shape # batch, height*width, channels
assert hw == h * w, "input feature has wrong size"
@ -424,8 +627,19 @@ class TinyViTBlock(nn.Module):
return x + self.drop_path(self.mlp(x))
def extra_repr(self) -> str:
"""Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of
attentions heads, window size, and MLP ratio.
"""
Returns a string representation of the TinyViTBlock's parameters.
This method provides a formatted string containing key information about the TinyViTBlock, including its
dimension, input resolution, number of attention heads, window size, and MLP ratio.
Returns:
(str): A formatted string containing the block's parameters.
Examples:
>>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0)
>>> print(block.extra_repr())
dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0
"""
return (
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
@ -434,7 +648,31 @@ class TinyViTBlock(nn.Module):
class BasicLayer(nn.Module):
"""A basic TinyViT layer for one stage in a TinyViT architecture."""
"""
A basic TinyViT layer for one stage in a TinyViT architecture.
This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks
and an optional downsampling operation.
Attributes:
dim (int): The dimensionality of the input and output features.
input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
depth (int): Number of TinyViT blocks in this layer.
use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
blocks (nn.ModuleList): List of TinyViT blocks that make up this layer.
downsample (nn.Module | None): Downsample layer at the end of the layer, if specified.
Methods:
forward: Processes the input through the layer's blocks and optional downsampling.
extra_repr: Returns a string with the layer's parameters for printing.
Examples:
>>> input_tensor = torch.randn(1, 3136, 192)
>>> layer = BasicLayer(dim=192, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
>>> output = layer(input_tensor)
>>> print(output.shape)
torch.Size([1, 784, 384])
"""
def __init__(
self,
@ -453,25 +691,34 @@ class BasicLayer(nn.Module):
out_dim=None,
):
"""
Initializes the BasicLayer.
Initializes a BasicLayer in the TinyViT architecture.
This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to
process feature maps at a specific resolution and dimensionality within the TinyViT model.
Args:
dim (int): The dimensionality of the input and output.
input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
depth (int): Number of TinyViT blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4.
drop (float, optional): Dropout rate. Default is 0.
drop_path (float | tuple[float], optional): Stochastic depth rate. Default is 0.
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default is None.
use_checkpoint (bool, optional): Whether to use checkpointing to save memory. Default is False.
local_conv_size (int, optional): Kernel size of the local convolution. Default is 3.
activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU.
out_dim (int | None, optional): The output dimension of the layer. Default is None.
dim (int): Dimensionality of the input and output features.
input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).
depth (int): Number of TinyViT blocks in this layer.
num_heads (int): Number of attention heads in each TinyViT block.
window_size (int): Size of the local window for attention computation.
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
drop (float): Dropout rate.
drop_path (float | List[float]): Stochastic depth rate. Can be a float or a list of floats for each block.
downsample (nn.Module | None): Downsampling layer at the end of the layer. None to skip downsampling.
use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
local_conv_size (int): Kernel size for the local convolution in each TinyViT block.
activation (nn.Module): Activation function used in the MLP.
out_dim (int | None): Output dimension after downsampling. None means it will be the same as `dim`.
Raises:
ValueError: If `drop_path` is a list of float but its length doesn't match `depth`.
ValueError: If `drop_path` is a list and its length doesn't match `depth`.
Examples:
>>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
>>> x = torch.randn(1, 56*56, 96)
>>> output = layer(x)
>>> print(output.shape)
"""
super().__init__()
self.dim = dim
@ -505,58 +752,49 @@ class BasicLayer(nn.Module):
)
def forward(self, x):
"""Performs forward propagation on the input tensor and returns a normalized tensor."""
"""Processes input through TinyViT blocks and optional downsampling."""
for blk in self.blocks:
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
return x if self.downsample is None else self.downsample(x)
def extra_repr(self) -> str:
"""Returns a string representation of the extra_repr function with the layer's parameters."""
"""Returns a string with the layer's parameters for printing."""
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
class LayerNorm2d(nn.Module):
"""A PyTorch implementation of Layer Normalization in 2D."""
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
"""Initialize LayerNorm2d with the number of channels and an optional epsilon."""
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Perform a forward pass, normalizing the input tensor."""
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
return self.weight[:, None, None] * x + self.bias[:, None, None]
class TinyViT(nn.Module):
"""
The TinyViT architecture for vision tasks.
TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.
This class implements the TinyViT model, which combines elements of vision transformers and convolutional
neural networks for improved efficiency and performance on vision tasks.
Attributes:
img_size (int): Input image size.
in_chans (int): Number of input channels.
num_classes (int): Number of classification classes.
embed_dims (List[int]): List of embedding dimensions for each layer.
depths (List[int]): List of depths for each layer.
num_heads (List[int]): List of number of attention heads for each layer.
window_sizes (List[int]): List of window sizes for each layer.
depths (List[int]): Number of blocks in each stage.
num_layers (int): Total number of layers in the network.
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
drop_rate (float): Dropout rate for drop layers.
drop_path_rate (float): Drop path rate for stochastic depth.
use_checkpoint (bool): Use checkpointing for efficient memory usage.
mbconv_expand_ratio (float): Expansion ratio for MBConv layer.
local_conv_size (int): Local convolution kernel size.
layer_lr_decay (float): Layer-wise learning rate decay.
Note:
This implementation is generalized to accept a list of depths, attention heads,
embedding dimensions and window sizes, which allows you to create a
"stack" of TinyViT models of varying configurations.
patch_embed (PatchEmbed): Module for patch embedding.
patches_resolution (Tuple[int, int]): Resolution of embedded patches.
layers (nn.ModuleList): List of network layers.
norm_head (nn.LayerNorm): Layer normalization for the classifier head.
head (nn.Linear): Linear layer for final classification.
neck (nn.Sequential): Neck module for feature refinement.
Methods:
set_layer_lr_decay: Sets layer-wise learning rate decay.
_init_weights: Initializes weights for linear and normalization layers.
no_weight_decay_keywords: Returns keywords for parameters that should not use weight decay.
forward_features: Processes input through the feature extraction layers.
forward: Performs a forward pass through the entire network.
Examples:
>>> model = TinyViT(img_size=224, num_classes=1000)
>>> x = torch.randn(1, 3, 224, 224)
>>> features = model.forward_features(x)
>>> print(features.shape)
torch.Size([1, 256, 64, 64])
"""
def __init__(
@ -579,21 +817,33 @@ class TinyViT(nn.Module):
"""
Initializes the TinyViT model.
This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of
attention and convolution blocks, and a classification head.
Args:
img_size (int, optional): The input image size. Defaults to 224.
in_chans (int, optional): Number of input channels. Defaults to 3.
num_classes (int, optional): Number of classification classes. Defaults to 1000.
embed_dims (List[int], optional): List of embedding dimensions per layer. Defaults to [96, 192, 384, 768].
depths (List[int], optional): List of depths for each layer. Defaults to [2, 2, 6, 2].
num_heads (List[int], optional): List of number of attention heads per layer. Defaults to [3, 6, 12, 24].
window_sizes (List[int], optional): List of window sizes for each layer. Defaults to [7, 7, 14, 7].
mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension. Defaults to 4.
drop_rate (float, optional): Dropout rate. Defaults to 0.
drop_path_rate (float, optional): Drop path rate for stochastic depth. Defaults to 0.1.
use_checkpoint (bool, optional): Whether to use checkpointing for efficient memory usage. Defaults to False.
mbconv_expand_ratio (float, optional): Expansion ratio for MBConv layer. Defaults to 4.0.
local_conv_size (int, optional): Local convolution kernel size. Defaults to 3.
layer_lr_decay (float, optional): Layer-wise learning rate decay. Defaults to 1.0.
img_size (int): Size of the input image. Default is 224.
in_chans (int): Number of input channels. Default is 3.
num_classes (int): Number of classes for classification. Default is 1000.
embed_dims (Tuple[int, int, int, int]): Embedding dimensions for each stage.
Default is (96, 192, 384, 768).
depths (Tuple[int, int, int, int]): Number of blocks in each stage. Default is (2, 2, 6, 2).
num_heads (Tuple[int, int, int, int]): Number of attention heads in each stage.
Default is (3, 6, 12, 24).
window_sizes (Tuple[int, int, int, int]): Window sizes for each stage. Default is (7, 7, 14, 7).
mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. Default is 4.0.
drop_rate (float): Dropout rate. Default is 0.0.
drop_path_rate (float): Stochastic depth rate. Default is 0.1.
use_checkpoint (bool): Whether to use checkpointing to save memory. Default is False.
mbconv_expand_ratio (float): Expansion ratio for MBConv layer. Default is 4.0.
local_conv_size (int): Kernel size for local convolutions. Default is 3.
layer_lr_decay (float): Layer-wise learning rate decay factor. Default is 1.0.
Examples:
>>> model = TinyViT(img_size=224, num_classes=1000)
>>> x = torch.randn(1, 3, 224, 224)
>>> output = model(x)
>>> print(output.shape)
torch.Size([1, 1000])
"""
super().__init__()
self.img_size = img_size
@ -671,7 +921,7 @@ class TinyViT(nn.Module):
)
def set_layer_lr_decay(self, layer_lr_decay):
"""Sets the learning rate decay for each layer in the TinyViT model."""
"""Sets layer-wise learning rate decay for the TinyViT model based on depth."""
decay_rate = layer_lr_decay
# Layers -> blocks (depth)
@ -706,7 +956,7 @@ class TinyViT(nn.Module):
self.apply(_check_lr_scale)
def _init_weights(self, m):
"""Initializes weights for linear layers and layer normalization in the given module."""
"""Initializes weights for linear and normalization layers in the TinyViT model."""
if isinstance(m, nn.Linear):
# NOTE: This initialization is needed only for training.
# trunc_normal_(m.weight, std=.02)
@ -718,11 +968,11 @@ class TinyViT(nn.Module):
@torch.jit.ignore
def no_weight_decay_keywords(self):
"""Returns a dictionary of parameter names where weight decay should not be applied."""
"""Returns a set of keywords for parameters that should not use weight decay."""
return {"attention_biases"}
def forward_features(self, x):
"""Runs the input through the model layers and returns the transformed output."""
"""Processes input through feature extraction layers, returning spatial features."""
x = self.patch_embed(x) # x input is (N, C, H, W)
x = self.layers[0](x)
@ -737,5 +987,5 @@ class TinyViT(nn.Module):
return self.neck(x)
def forward(self, x):
"""Executes a forward pass on the input tensor through the constructed model layers."""
"""Performs the forward pass through the TinyViT model, extracting features from the input image."""
return self.forward_features(x)

@ -11,19 +11,31 @@ from ultralytics.nn.modules import MLPBlock
class TwoWayTransformer(nn.Module):
"""
A Two-Way Transformer module that enables the simultaneous attention to both image and query points. This class
serves as a specialized transformer decoder that attends to an input image using queries whose positional embedding
is supplied. This is particularly useful for tasks like object detection, image segmentation, and point cloud
processing.
A Two-Way Transformer module for simultaneous attention to image and query points.
This class implements a specialized transformer decoder that attends to an input image using queries with
supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point
cloud processing.
Attributes:
depth (int): The number of layers in the transformer.
embedding_dim (int): The channel dimension for the input embeddings.
num_heads (int): The number of heads for multihead attention.
mlp_dim (int): The internal channel dimension for the MLP block.
layers (nn.ModuleList): The list of TwoWayAttentionBlock layers that make up the transformer.
final_attn_token_to_image (Attention): The final attention layer applied from the queries to the image.
norm_final_attn (nn.LayerNorm): The layer normalization applied to the final queries.
depth (int): Number of layers in the transformer.
embedding_dim (int): Channel dimension for input embeddings.
num_heads (int): Number of heads for multihead attention.
mlp_dim (int): Internal channel dimension for the MLP block.
layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer.
final_attn_token_to_image (Attention): Final attention layer from queries to image.
norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
Methods:
forward: Processes image and point embeddings through the transformer.
Examples:
>>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
>>> image_embedding = torch.randn(1, 256, 32, 32)
>>> image_pe = torch.randn(1, 256, 32, 32)
>>> point_embedding = torch.randn(1, 100, 256)
>>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
>>> print(output_queries.shape, output_image.shape)
"""
def __init__(
@ -36,15 +48,33 @@ class TwoWayTransformer(nn.Module):
attention_downsample_rate: int = 2,
) -> None:
"""
A transformer decoder that attends to an input image using queries whose positional embedding is supplied.
Initialize a Two-Way Transformer for simultaneous attention to image and query points.
Args:
depth (int): number of layers in the transformer
embedding_dim (int): the channel dimension for the input embeddings
num_heads (int): the number of heads for multihead attention. Must
divide embedding_dim
mlp_dim (int): the channel dimension internal to the MLP block
activation (nn.Module): the activation to use in the MLP block
depth (int): Number of layers in the transformer.
embedding_dim (int): Channel dimension for input embeddings.
num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
mlp_dim (int): Internal channel dimension for the MLP block.
activation (Type[nn.Module]): Activation function to use in the MLP block.
attention_downsample_rate (int): Downsampling rate for attention mechanism.
Attributes:
depth (int): Number of layers in the transformer.
embedding_dim (int): Channel dimension for input embeddings.
embedding_dim (int): Channel dimension for input embeddings.
num_heads (int): Number of heads for multihead attention.
mlp_dim (int): Internal channel dimension for the MLP block.
layers (nn.ModuleList): List of TwoWayAttentionBlock layers.
final_attn_token_to_image (Attention): Final attention layer from queries to image.
norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
Examples:
>>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
>>> image_embedding = torch.randn(1, 256, 32, 32)
>>> image_pe = torch.randn(1, 256, 32, 32)
>>> point_embedding = torch.randn(1, 100, 256)
>>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
>>> print(output_queries.shape, output_image.shape)
"""
super().__init__()
self.depth = depth
@ -75,15 +105,23 @@ class TwoWayTransformer(nn.Module):
point_embedding: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Processes image and point embeddings through the Two-Way Transformer.
Args:
image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must have same shape as image_embedding.
point_embedding (torch.Tensor): the embedding to add to the query points.
Must have shape B x N_points x embedding_dim for any N_points.
image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding.
point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).
Returns:
(torch.Tensor): the processed point_embedding
(torch.Tensor): the processed image_embedding
(Tuple[torch.Tensor, torch.Tensor]): Processed point_embedding and image_embedding.
Examples:
>>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
>>> image_embedding = torch.randn(1, 256, 32, 32)
>>> image_pe = torch.randn(1, 256, 32, 32)
>>> point_embedding = torch.randn(1, 100, 256)
>>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
>>> print(output_queries.shape, output_image.shape)
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
@ -114,21 +152,34 @@ class TwoWayTransformer(nn.Module):
class TwoWayAttentionBlock(nn.Module):
"""
An attention block that performs both self-attention and cross-attention in two directions: queries to keys and
keys to queries. This block consists of four main layers: (1) self-attention on sparse inputs, (2) cross-attention
of sparse inputs to dense inputs, (3) an MLP block on sparse inputs, and (4) cross-attention of dense inputs to
sparse inputs.
A two-way attention block for simultaneous attention to image and query points.
This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense
inputs to sparse inputs.
Attributes:
self_attn (Attention): The self-attention layer for the queries.
norm1 (nn.LayerNorm): Layer normalization following the first attention block.
self_attn (Attention): Self-attention layer for queries.
norm1 (nn.LayerNorm): Layer normalization after self-attention.
cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
norm2 (nn.LayerNorm): Layer normalization following the second attention block.
mlp (MLPBlock): MLP block that transforms the query embeddings.
norm3 (nn.LayerNorm): Layer normalization following the MLP block.
norm4 (nn.LayerNorm): Layer normalization following the third attention block.
norm2 (nn.LayerNorm): Layer normalization after token-to-image attention.
mlp (MLPBlock): MLP block for transforming query embeddings.
norm3 (nn.LayerNorm): Layer normalization after MLP block.
norm4 (nn.LayerNorm): Layer normalization after image-to-token attention.
cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
Methods:
forward: Applies self-attention and cross-attention to queries and keys.
Examples:
>>> embedding_dim, num_heads = 256, 8
>>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
>>> queries = torch.randn(1, 100, embedding_dim)
>>> keys = torch.randn(1, 1000, embedding_dim)
>>> query_pe = torch.randn(1, 100, embedding_dim)
>>> key_pe = torch.randn(1, 1000, embedding_dim)
>>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
"""
def __init__(
@ -141,16 +192,28 @@ class TwoWayAttentionBlock(nn.Module):
skip_first_layer_pe: bool = False,
) -> None:
"""
A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse
inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse
inputs.
Initializes a TwoWayAttentionBlock for simultaneous attention to image and query points.
This block implements a specialized transformer layer with four main components: self-attention on sparse
inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
of dense inputs to sparse inputs.
Args:
embedding_dim (int): the channel dimension of the embeddings
num_heads (int): the number of heads in the attention layers
mlp_dim (int): the hidden dimension of the mlp block
activation (nn.Module): the activation of the mlp block
skip_first_layer_pe (bool): skip the PE on the first layer
embedding_dim (int): Channel dimension of the embeddings.
num_heads (int): Number of attention heads in the attention layers.
mlp_dim (int): Hidden dimension of the MLP block.
activation (Type[nn.Module]): Activation function for the MLP block.
attention_downsample_rate (int): Downsampling rate for the attention mechanism.
skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
Examples:
>>> embedding_dim, num_heads = 256, 8
>>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
>>> queries = torch.randn(1, 100, embedding_dim)
>>> keys = torch.randn(1, 1000, embedding_dim)
>>> query_pe = torch.randn(1, 100, embedding_dim)
>>> key_pe = torch.randn(1, 1000, embedding_dim)
>>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
"""
super().__init__()
self.self_attn = Attention(embedding_dim, num_heads)
@ -168,7 +231,7 @@ class TwoWayAttentionBlock(nn.Module):
self.skip_first_layer_pe = skip_first_layer_pe
def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
"""Apply self-attention and cross-attention to queries and keys and return the processed embeddings."""
"""Applies two-way attention to process query and key embeddings in a transformer block."""
# Self attention block
if self.skip_first_layer_pe:
@ -202,8 +265,34 @@ class TwoWayAttentionBlock(nn.Module):
class Attention(nn.Module):
"""An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
values.
"""
An attention layer with downscaling capability for embedding size after projection.
This class implements a multi-head attention mechanism with the option to downsample the internal
dimension of queries, keys, and values.
Attributes:
embedding_dim (int): Dimensionality of input embeddings.
kv_in_dim (int): Dimensionality of key and value inputs.
internal_dim (int): Internal dimension after downsampling.
num_heads (int): Number of attention heads.
q_proj (nn.Linear): Linear projection for queries.
k_proj (nn.Linear): Linear projection for keys.
v_proj (nn.Linear): Linear projection for values.
out_proj (nn.Linear): Linear projection for output.
Methods:
_separate_heads: Separates input tensor into attention heads.
_recombine_heads: Recombines separated attention heads.
forward: Computes attention output for given query, key, and value tensors.
Examples:
>>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
>>> q = torch.randn(1, 100, 256)
>>> k = v = torch.randn(1, 50, 256)
>>> output = attn(q, k, v)
>>> print(output.shape)
torch.Size([1, 100, 256])
"""
def __init__(
@ -214,15 +303,27 @@ class Attention(nn.Module):
kv_in_dim: int = None,
) -> None:
"""
Initializes the Attention model with the given dimensions and settings.
Initializes the Attention module with specified dimensions and settings.
This class implements a multi-head attention mechanism with optional downsampling of the internal
dimension for queries, keys, and values.
Args:
embedding_dim (int): The dimensionality of the input embeddings.
num_heads (int): The number of attention heads.
downsample_rate (int, optional): The factor by which the internal dimensions are downsampled. Defaults to 1.
embedding_dim (int): Dimensionality of input embeddings.
num_heads (int): Number of attention heads.
downsample_rate (int): Factor by which internal dimensions are downsampled. Defaults to 1.
kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim.
Raises:
AssertionError: If 'num_heads' does not evenly divide the internal dim (embedding_dim / downsample_rate).
AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).
Examples:
>>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
>>> q = torch.randn(1, 100, 256)
>>> k = v = torch.randn(1, 50, 256)
>>> output = attn(q, k, v)
>>> print(output.shape)
torch.Size([1, 100, 256])
"""
super().__init__()
self.embedding_dim = embedding_dim
@ -238,20 +339,20 @@ class Attention(nn.Module):
@staticmethod
def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
"""Separate the input tensor into the specified number of attention heads."""
"""Separates the input tensor into the specified number of attention heads."""
b, n, c = x.shape
x = x.reshape(b, n, num_heads, c // num_heads)
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
@staticmethod
def _recombine_heads(x: Tensor) -> Tensor:
"""Recombine the separated attention heads into a single tensor."""
"""Recombines separated attention heads into a single tensor."""
b, n_heads, n_tokens, c_per_head = x.shape
x = x.transpose(1, 2)
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
"""Compute the attention output given the input query, key, and value tensors."""
"""Applies multi-head attention to query, key, and value tensors with optional downsampling."""
# Input projections
q = self.q_proj(q)

@ -1,5 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from typing import Tuple
import torch
import torch.nn.functional as F
@ -70,7 +72,7 @@ def get_1d_sine_pe(pos_inds, dim, temperature=10000):
def init_t_xy(end_x: int, end_y: int):
"""Initializes 1D and 2D coordinate tensors for a grid of size end_x by end_y."""
"""Initializes 1D and 2D coordinate tensors for a grid of specified dimensions."""
t = torch.arange(end_x * end_y, dtype=torch.float32)
t_x = (t % end_x).float()
t_y = torch.div(t, end_x, rounding_mode="floor").float()
@ -78,7 +80,7 @@ def init_t_xy(end_x: int, end_y: int):
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
"""Computes axial complex exponential positional encodings for 2D spatial positions."""
"""Computes axial complex exponential positional encodings for 2D spatial positions in a grid."""
freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
@ -91,7 +93,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
"""Reshapes frequency tensor for broadcasting, ensuring compatibility with input tensor dimensions."""
"""Reshapes frequency tensor for broadcasting with input tensor, ensuring dimensional compatibility."""
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
@ -189,3 +191,103 @@ def window_unpartition(windows, window_size, pad_hw, hw):
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Extracts relative positional embeddings based on query and key sizes.
Args:
q_size (int): Size of the query.
k_size (int): Size of the key.
rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative
distance and C is the embedding dimension.
Returns:
(torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,
k_size, C).
Examples:
>>> q_size, k_size = 8, 16
>>> rel_pos = torch.randn(31, 64) # 31 = 2 * max(8, 16) - 1
>>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos)
>>> print(extracted_pos.shape)
torch.Size([8, 16, 64])
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Adds decomposed Relative Positional Embeddings to the attention map.
This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
positions.
Args:
attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w).
q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C).
rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C).
q_size (Tuple[int, int]): Spatial sequence size of query q as (q_h, q_w).
k_size (Tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).
Returns:
(torch.Tensor): Updated attention map with added relative positional embeddings, shape
(B, q_h * q_w, k_h * k_w).
Examples:
>>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
>>> attn = torch.rand(B, q_h * q_w, k_h * k_w)
>>> q = torch.rand(B, q_h * q_w, C)
>>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C)
>>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C)
>>> q_size, k_size = (q_h, q_w), (k_h, k_w)
>>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size)
>>> print(updated_attn.shape)
torch.Size([1, 64, 64])
References:
https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
B, q_h * q_w, k_h * k_w
)
return attn

@ -34,35 +34,64 @@ from .build import build_sam
class Predictor(BasePredictor):
"""
Predictor class for the Segment Anything Model (SAM), extending BasePredictor.
Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.
The class provides an interface for model inference tailored to image segmentation tasks.
With advanced architecture and promptable segmentation capabilities, it facilitates flexible and real-time
mask generation. The class is capable of working with various types of prompts such as bounding boxes,
points, and low-resolution masks.
This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image
segmentation tasks. It supports various input prompts like points, bounding boxes, and masks for
fine-grained control over segmentation results.
Attributes:
cfg (dict): Configuration dictionary specifying model and task-related parameters.
overrides (dict): Dictionary containing values that override the default configuration.
_callbacks (dict): Dictionary of user-defined callback functions to augment behavior.
args (namespace): Namespace to hold command-line arguments or other operational variables.
im (torch.Tensor): Preprocessed input image tensor.
features (torch.Tensor): Extracted image features used for inference.
prompts (dict): Collection of various prompt types, such as bounding boxes and points.
segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones.
args (SimpleNamespace): Configuration arguments for the predictor.
model (torch.nn.Module): The loaded SAM model.
device (torch.device): The device (CPU or GPU) on which the model is loaded.
im (torch.Tensor): The preprocessed input image.
features (torch.Tensor): Extracted image features.
prompts (Dict): Dictionary to store various types of prompts (e.g., bboxes, points, masks).
segment_all (bool): Flag to indicate if full image segmentation should be performed.
mean (torch.Tensor): Mean values for image normalization.
std (torch.Tensor): Standard deviation values for image normalization.
Methods:
preprocess: Prepares input images for model inference.
pre_transform: Performs initial transformations on the input image.
inference: Performs segmentation inference based on input prompts.
prompt_inference: Internal function for prompt-based segmentation inference.
generate: Generates segmentation masks for an entire image.
setup_model: Initializes the SAM model for inference.
get_model: Builds and returns a SAM model.
postprocess: Post-processes model outputs to generate final results.
setup_source: Sets up the data source for inference.
set_image: Sets and preprocesses a single image for inference.
get_im_features: Extracts image features using the SAM image encoder.
set_prompts: Sets prompts for subsequent inference.
reset_image: Resets the current image and its features.
remove_small_regions: Removes small disconnected regions and holes from masks.
Examples:
>>> predictor = Predictor()
>>> predictor.setup_model(model_path='sam_model.pt')
>>> predictor.set_image('image.jpg')
>>> masks, scores, boxes = predictor.generate()
>>> results = predictor.postprocess((masks, scores, boxes), im, orig_img)
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize the Predictor with configuration, overrides, and callbacks.
The method sets up the Predictor object and applies any configuration overrides or callbacks provided. It
initializes task-specific settings for SAM, such as retina_masks being set to True for optimal results.
Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or
callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True
for optimal results.
Args:
cfg (dict): Configuration dictionary.
overrides (dict, optional): Dictionary of values to override default configuration.
_callbacks (dict, optional): Dictionary of callback functions to customize behavior.
cfg (Dict): Configuration dictionary containing default settings.
overrides (Dict | None): Dictionary of values to override default configuration.
_callbacks (Dict | None): Dictionary of callback functions to customize behavior.
Examples:
>>> predictor = Predictor(cfg=DEFAULT_CFG)
>>> predictor = Predictor(overrides={'imgsz': 640})
>>> predictor = Predictor(_callbacks={'on_predict_start': custom_callback})
"""
if overrides is None:
overrides = {}
@ -78,14 +107,19 @@ class Predictor(BasePredictor):
"""
Preprocess the input image for model inference.
The method prepares the input image by applying transformations and normalization.
It supports both torch.Tensor and list of np.ndarray as input formats.
This method prepares the input image by applying transformations and normalization. It supports both
torch.Tensor and list of np.ndarray as input formats.
Args:
im (torch.Tensor | List[np.ndarray]): BCHW tensor format or list of HWC numpy arrays.
im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.
Returns:
(torch.Tensor): The preprocessed image tensor.
(torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
Examples:
>>> predictor = Predictor()
>>> image = torch.rand(1, 3, 640, 640)
>>> preprocessed_image = predictor.preprocess(image)
"""
if self.im is not None:
return self.im
@ -106,14 +140,24 @@ class Predictor(BasePredictor):
"""
Perform initial transformations on the input image for preprocessing.
The method applies transformations such as resizing to prepare the image for further preprocessing.
This method applies transformations such as resizing to prepare the image for further preprocessing.
Currently, batched inference is not supported; hence the list length should be 1.
Args:
im (List[np.ndarray]): List containing images in HWC numpy array format.
im (List[np.ndarray]): List containing a single image in HWC numpy array format.
Returns:
(List[np.ndarray]): List of transformed images.
(List[np.ndarray]): List containing the transformed image.
Raises:
AssertionError: If the input list contains more than one image.
Examples:
>>> predictor = Predictor()
>>> image = np.random.rand(480, 640, 3) # Single HWC image
>>> transformed = predictor.pre_transform([image])
>>> print(len(transformed))
1
"""
assert len(im) == 1, "SAM model does not currently support batched inference"
letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
@ -121,23 +165,32 @@ class Predictor(BasePredictor):
def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
"""
Perform image segmentation inference based on the given input cues, using the currently loaded image. This
method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
mask decoder for real-time and promptable segmentation tasks.
Perform image segmentation inference based on the given input cues, using the currently loaded image.
This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt
encoder, and mask decoder for real-time and promptable segmentation tasks.
Args:
im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts.
bboxes (np.ndarray | List | None): Bounding boxes with shape (N, 4), in XYXY format.
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
labels (np.ndarray | List | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background.
masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W). For SAM H=W=256.
multimask_output (bool): Flag to return multiple masks. Helpful for ambiguous prompts.
*args (Any): Additional positional arguments.
**kwargs (Any): Additional keyword arguments.
Returns:
(tuple): Contains the following three elements.
- np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
(tuple): Contains the following three elements:
- np.ndarray: The output masks in shape (C, H, W), where C is the number of generated masks.
- np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
- np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
- np.ndarray: Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
Examples:
>>> predictor = Predictor()
>>> predictor.setup_model(model_path='sam_model.pt')
>>> predictor.set_image('image.jpg')
>>> masks, scores, logits = predictor.inference(im, bboxes=[[0, 0, 100, 100]])
"""
# Override prompts if any stored in self.prompts
bboxes = self.prompts.pop("bboxes", bboxes)
@ -151,22 +204,30 @@ class Predictor(BasePredictor):
def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
"""
Internal function for image segmentation inference based on cues like bounding boxes, points, and masks.
Leverages SAM's specialized architecture for prompt-based, real-time segmentation.
Performs image segmentation inference based on input cues using SAM's specialized architecture.
This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation.
It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks.
Args:
im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts.
im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
labels (np.ndarray | List | None): Point prompt labels with shape (N,). 1 for foreground, 0 for background.
masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256.
multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
Returns:
(tuple): Contains the following three elements.
- np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
- np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
- np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
(tuple): Tuple containing:
- np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
- np.ndarray: Quality scores predicted by the model for each mask, with length C.
- np.ndarray: Low-resolution logits with shape (C, H, W) for subsequent inference, where H=W=256.
Examples:
>>> predictor = Predictor()
>>> im = torch.rand(1, 3, 1024, 1024)
>>> bboxes = [[100, 100, 200, 200]]
>>> masks, scores, logits = predictor.prompt_inference(im, bboxes=bboxes)
"""
features = self.get_im_features(im) if self.features is None else self.features
@ -224,27 +285,32 @@ class Predictor(BasePredictor):
"""
Perform image segmentation using the Segment Anything Model (SAM).
This function segments an entire image into constituent parts by leveraging SAM's advanced architecture
This method segments an entire image into constituent parts by leveraging SAM's advanced architecture
and real-time performance capabilities. It can optionally work on image crops for finer segmentation.
Args:
im (torch.Tensor): Input tensor representing the preprocessed image with dimensions (N, C, H, W).
crop_n_layers (int): Specifies the number of layers for additional mask predictions on image crops.
Each layer produces 2**i_layer number of image crops.
crop_overlap_ratio (float): Determines the overlap between crops. Scaled down in subsequent layers.
crop_downscale_factor (int): Scaling factor for the number of sampled points-per-side in each layer.
point_grids (list[np.ndarray], optional): Custom grids for point sampling normalized to [0,1].
Used in the nth crop layer.
points_stride (int, optional): Number of points to sample along each side of the image.
Exclusive with 'point_grids'.
im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W).
crop_n_layers (int): Number of layers for additional mask predictions on image crops.
crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers.
crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer.
point_grids (List[np.ndarray] | None): Custom grids for point sampling normalized to [0,1].
points_stride (int): Number of points to sample along each side of the image.
points_batch_size (int): Batch size for the number of points processed simultaneously.
conf_thres (float): Confidence threshold [0,1] for filtering based on the model's mask quality prediction.
stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on mask stability.
conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction.
stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on stability.
stability_score_offset (float): Offset value for calculating stability score.
crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops.
Returns:
(tuple): A tuple containing segmented masks, confidence scores, and bounding boxes.
(Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
- pred_masks (torch.Tensor): Segmented masks with shape (N, H, W).
- pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,).
- pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4).
Examples:
>>> predictor = Predictor()
>>> im = torch.rand(1, 3, 1024, 1024) # Example input image
>>> masks, scores, boxes = predictor.generate(im)
"""
import torchvision # scope for faster 'import ultralytics'
@ -326,11 +392,9 @@ class Predictor(BasePredictor):
model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration.
verbose (bool): If True, prints selected device information.
Attributes:
model (torch.nn.Module): The SAM model allocated to the chosen device for inference.
device (torch.device): The device to which the model and tensors are allocated.
mean (torch.Tensor): The mean values for image normalization.
std (torch.Tensor): The standard deviation values for image normalization.
Examples:
>>> predictor = Predictor()
>>> predictor.setup_model(model=sam_model, verbose=True)
"""
device = select_device(self.args.device, verbose=verbose)
if model is None:
@ -349,23 +413,32 @@ class Predictor(BasePredictor):
self.done_warmup = True
def get_model(self):
"""Built Segment Anything Model (SAM) model."""
"""Retrieves or builds the Segment Anything Model (SAM) for image segmentation tasks."""
return build_sam(self.args.model)
def postprocess(self, preds, img, orig_imgs):
"""
Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
The method scales masks and boxes to the original image size and applies a threshold to the mask predictions.
The SAM model uses advanced architecture and promptable segmentation tasks to achieve real-time performance.
This method scales masks and boxes to the original image size and applies a threshold to the mask
predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks.
Args:
preds (tuple): The output from SAM model inference, containing masks, scores, and optional bounding boxes.
img (torch.Tensor): The processed input image tensor.
orig_imgs (list | torch.Tensor): The original, unprocessed images.
preds (Tuple[torch.Tensor]): The output from SAM model inference, containing:
- pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W).
- pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1).
- pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True.
img (torch.Tensor): The processed input image tensor with shape (C, H, W).
orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images.
Returns:
(list): List of Results objects containing detection masks, bounding boxes, and other metadata.
(List[Results]): List of Results objects containing detection masks, bounding boxes, and other
metadata for each processed image.
Examples:
>>> predictor = Predictor()
>>> preds = predictor.inference(img)
>>> results = predictor.postprocess(preds, img, orig_imgs)
"""
# (N, 1, H, W), (N, 1)
pred_masks, pred_scores = preds[:2]
@ -393,11 +466,23 @@ class Predictor(BasePredictor):
"""
Sets up the data source for inference.
This method configures the data source from which images will be fetched for inference. The source could be a
directory, a video file, or other types of image data sources.
This method configures the data source from which images will be fetched for inference. It supports
various input types such as image files, directories, video files, and other compatible data sources.
Args:
source (str | Path): The path to the image data source for inference.
source (str | Path | None): The path or identifier for the image data source. Can be a file path,
directory path, URL, or other supported source types.
Examples:
>>> predictor = Predictor()
>>> predictor.setup_source('path/to/images')
>>> predictor.setup_source('video.mp4')
>>> predictor.setup_source(None) # Uses default source if available
Notes:
- If source is None, the method may use a default source if configured.
- The method adapts to different source types and prepares them for subsequent inference steps.
- Supported source types may include local files, directories, URLs, and video streams.
"""
if source is not None:
super().setup_source(source)
@ -406,14 +491,25 @@ class Predictor(BasePredictor):
"""
Preprocesses and sets a single image for inference.
This function sets up the model if not already initialized, configures the data source to the specified image,
and preprocesses the image for feature extraction. Only one image can be set at a time.
This method prepares the model for inference on a single image by setting up the model if not already
initialized, configuring the data source, and preprocessing the image for feature extraction. It
ensures that only one image is set at a time and extracts image features for subsequent use.
Args:
image (str | np.ndarray): Image file path as a string, or a np.ndarray image read by cv2.
image (str | np.ndarray): Path to the image file as a string, or a numpy array representing
an image read by cv2.
Raises:
AssertionError: If more than one image is set.
AssertionError: If more than one image is attempted to be set.
Examples:
>>> predictor = Predictor()
>>> predictor.set_image('path/to/image.jpg')
>>> predictor.set_image(cv2.imread('path/to/image.jpg'))
Notes:
- This method should be called before performing inference on a new image.
- The extracted features are stored in the `self.features` attribute for later use.
"""
if self.model is None:
self.setup_model(model=None)
@ -425,35 +521,44 @@ class Predictor(BasePredictor):
break
def get_im_features(self, im):
"""Get image features from the SAM image encoder."""
"""Extracts image features using the SAM model's image encoder for subsequent mask prediction."""
return self.model.image_encoder(im)
def set_prompts(self, prompts):
"""Set prompts in advance."""
"""Sets prompts for subsequent inference operations."""
self.prompts = prompts
def reset_image(self):
"""Resets the image and its features to None."""
"""Resets the current image and its features, clearing them for subsequent inference."""
self.im = None
self.features = None
@staticmethod
def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
"""
Perform post-processing on segmentation masks generated by the Segment Anything Model (SAM). Specifically, this
function removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
Remove small disconnected regions and holes from segmentation masks.
This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM).
It removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
Suppression (NMS) to eliminate any newly created duplicate boxes.
Args:
masks (torch.Tensor): A tensor containing the masks to be processed. Shape should be (N, H, W), where N is
the number of masks, H is height, and W is width.
min_area (int): The minimum area below which disconnected regions and holes will be removed. Defaults to 0.
nms_thresh (float): The IoU threshold for the NMS algorithm. Defaults to 0.7.
masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of
masks, H is height, and W is width.
min_area (int): Minimum area threshold for removing disconnected regions and holes. Regions smaller than
this will be removed.
nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes.
Returns:
(tuple([torch.Tensor, List[int]])):
- new_masks (torch.Tensor): The processed masks with small regions removed. Shape is (N, H, W).
- keep (List[int]): The indices of the remaining masks post-NMS, which can be used to filter the boxes.
(tuple):
- new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
- keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
Examples:
>>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks
>>> new_masks, keep = remove_small_regions(masks, min_area=100, nms_thresh=0.7)
>>> print(f"Original masks: {masks.shape}, Processed masks: {new_masks.shape}")
>>> print(f"Indices of kept masks: {keep}")
"""
import torchvision # scope for faster 'import ultralytics'
@ -480,3 +585,188 @@ class Predictor(BasePredictor):
keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh)
return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep
class SAM2Predictor(Predictor):
"""
SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.
This class extends the base Predictor class to implement SAM2-specific functionality for image
segmentation tasks. It provides methods for model initialization, feature extraction, and
prompt-based inference.
Attributes:
_bb_feat_sizes (List[Tuple[int, int]]): Feature sizes for different backbone levels.
model (torch.nn.Module): The loaded SAM2 model.
device (torch.device): The device (CPU or GPU) on which the model is loaded.
features (Dict[str, torch.Tensor]): Cached image features for efficient inference.
segment_all (bool): Flag to indicate if all segments should be predicted.
prompts (Dict): Dictionary to store various types of prompts for inference.
Methods:
get_model: Retrieves and initializes the SAM2 model.
prompt_inference: Performs image segmentation inference based on various prompts.
set_image: Preprocesses and sets a single image for inference.
get_im_features: Extracts and processes image features using SAM2's image encoder.
Examples:
>>> predictor = SAM2Predictor(cfg)
>>> predictor.set_image("path/to/image.jpg")
>>> bboxes = [[100, 100, 200, 200]]
>>> masks, scores, _ = predictor.prompt_inference(predictor.im, bboxes=bboxes)
>>> print(f"Predicted {len(masks)} masks with average score {scores.mean():.2f}")
"""
_bb_feat_sizes = [
(256, 256),
(128, 128),
(64, 64),
]
def get_model(self):
"""Retrieves and initializes the Segment Anything Model 2 (SAM2) for image segmentation tasks."""
return build_sam(self.args.model)
def prompt_inference(
self,
im,
bboxes=None,
points=None,
labels=None,
masks=None,
multimask_output=False,
img_idx=-1,
):
"""
Performs image segmentation inference based on various prompts using SAM2 architecture.
This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images
based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and
multi-object prediction scenarios.
Args:
im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4).
points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels.
labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W).
multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
img_idx (int): Index of the image in the batch to process.
Returns:
(tuple): Tuple containing:
- np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
- np.ndarray: Quality scores for each mask, with length C.
- np.ndarray: Low-resolution logits with shape (C, 256, 256) for subsequent inference.
Examples:
>>> predictor = SAM2Predictor(cfg)
>>> image = torch.rand(1, 3, 640, 640)
>>> bboxes = [[100, 100, 200, 200]]
>>> masks, scores, logits = predictor.prompt_inference(image, bboxes=bboxes)
>>> print(f"Generated {masks.shape[0]} masks with average score {scores.mean():.2f}")
Notes:
- The method supports batched inference for multiple objects when points or bboxes are provided.
- Input prompts (bboxes, points) are automatically scaled to match the input image dimensions.
- When both bboxes and points are provided, they are merged into a single 'points' input for the model.
References:
- SAM2 Paper: [Add link to SAM2 paper when available]
"""
features = self.get_im_features(im) if self.features is None else self.features
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
# Transform input prompts
if points is not None:
points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
points = points[None] if points.ndim == 1 else points
# Assuming labels are all positive if users don't pass labels.
if labels is None:
labels = torch.ones(points.shape[0])
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
points *= r
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
points, labels = points[:, None], labels[:, None]
if bboxes is not None:
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
bboxes = bboxes.view(-1, 2, 2) * r
bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
# NOTE: merge "boxes" and "points" into a single "points" input
# (where boxes are added at the beginning) to model.sam_prompt_encoder
if points is not None:
points = torch.cat([bboxes, points], dim=1)
labels = torch.cat([bbox_labels, labels], dim=1)
else:
points, labels = bboxes, bbox_labels
if masks is not None:
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
points = (points, labels) if points is not None else None
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
points=points,
boxes=None,
masks=masks,
)
# Predict masks
batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction
high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
image_embeddings=features["image_embed"][img_idx].unsqueeze(0),
image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
repeat_image=batched_mode,
high_res_features=high_res_features,
)
# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
# `d` could be 1 or 3 depends on `multimask_output`.
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
def set_image(self, image):
"""
Preprocesses and sets a single image for inference using the SAM2 model.
This method initializes the model if not already done, configures the data source to the specified image,
and preprocesses the image for feature extraction. It supports setting only one image at a time.
Args:
image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image.
Raises:
AssertionError: If more than one image is attempted to be set.
Examples:
>>> predictor = SAM2Predictor()
>>> predictor.set_image("path/to/image.jpg")
>>> predictor.set_image(np.array([...])) # Using a numpy array
Notes:
- This method must be called before performing any inference on a new image.
- The method caches the extracted features for efficient subsequent inferences on the same image.
- Only one image can be set at a time. To process multiple images, call this method for each new image.
"""
if self.model is None:
self.setup_model(model=None)
self.setup_source(image)
assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
for batch in self.dataset:
im = self.preprocess(batch[1])
self.features = self.get_im_features(im)
break
def get_im_features(self, im):
"""Extracts image features from the SAM image encoder for subsequent processing."""
backbone_out = self.model.forward_image(im)
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
if self.model.directly_add_no_mem_embed:
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
feats = [
feat.permute(1, 2, 0).view(1, -1, *feat_size)
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
][::-1]
return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}

@ -1,6 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .model import SAM2
from .predict import SAM2Predictor
__all__ = "SAM2", "SAM2Predictor" # tuple or list

@ -1,156 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.utils.downloads import attempt_download_asset
from .modules.encoders import FpnNeck, Hiera, ImageEncoder, MemoryEncoder
from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
from .modules.sam2 import SAM2Model
def build_sam2_t(checkpoint=None):
"""Build and return a Segment Anything Model (SAM2) tiny-size model with specified architecture parameters."""
return _build_sam2(
encoder_embed_dim=96,
encoder_stages=[1, 2, 7, 2],
encoder_num_heads=1,
encoder_global_att_blocks=[5, 7, 9],
encoder_window_spec=[8, 4, 14, 7],
encoder_backbone_channel_list=[768, 384, 192, 96],
checkpoint=checkpoint,
)
def build_sam2_s(checkpoint=None):
"""Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters."""
return _build_sam2(
encoder_embed_dim=96,
encoder_stages=[1, 2, 11, 2],
encoder_num_heads=1,
encoder_global_att_blocks=[7, 10, 13],
encoder_window_spec=[8, 4, 14, 7],
encoder_backbone_channel_list=[768, 384, 192, 96],
checkpoint=checkpoint,
)
def build_sam2_b(checkpoint=None):
"""Builds and returns a Segment Anything Model (SAM2) base-size model with specified architecture parameters."""
return _build_sam2(
encoder_embed_dim=112,
encoder_stages=[2, 3, 16, 3],
encoder_num_heads=2,
encoder_global_att_blocks=[12, 16, 20],
encoder_window_spec=[8, 4, 14, 7],
encoder_window_spatial_size=[14, 14],
encoder_backbone_channel_list=[896, 448, 224, 112],
checkpoint=checkpoint,
)
def build_sam2_l(checkpoint=None):
"""Build and return a Segment Anything Model (SAM2) large-size model with specified architecture parameters."""
return _build_sam2(
encoder_embed_dim=144,
encoder_stages=[2, 6, 36, 4],
encoder_num_heads=2,
encoder_global_att_blocks=[23, 33, 43],
encoder_window_spec=[8, 4, 16, 8],
encoder_backbone_channel_list=[1152, 576, 288, 144],
checkpoint=checkpoint,
)
def _build_sam2(
encoder_embed_dim=1280,
encoder_stages=[2, 6, 36, 4],
encoder_num_heads=2,
encoder_global_att_blocks=[7, 15, 23, 31],
encoder_backbone_channel_list=[1152, 576, 288, 144],
encoder_window_spatial_size=[7, 7],
encoder_window_spec=[8, 4, 16, 8],
checkpoint=None,
):
"""Builds a SAM2 model with specified architecture parameters and optional checkpoint loading."""
image_encoder = ImageEncoder(
trunk=Hiera(
embed_dim=encoder_embed_dim,
num_heads=encoder_num_heads,
stages=encoder_stages,
global_att_blocks=encoder_global_att_blocks,
window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
window_spec=encoder_window_spec,
),
neck=FpnNeck(
d_model=256,
backbone_channel_list=encoder_backbone_channel_list,
fpn_top_down_levels=[2, 3],
fpn_interp_model="nearest",
),
scalp=1,
)
memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
memory_encoder = MemoryEncoder(out_dim=64)
sam2 = SAM2Model(
image_encoder=image_encoder,
memory_attention=memory_attention,
memory_encoder=memory_encoder,
num_maskmem=7,
image_size=1024,
sigmoid_scale_for_mem_enc=20.0,
sigmoid_bias_for_mem_enc=-10.0,
use_mask_input_as_output_without_sam=True,
directly_add_no_mem_embed=True,
use_high_res_features_in_sam=True,
multimask_output_in_sam=True,
iou_prediction_use_sigmoid=True,
use_obj_ptrs_in_encoder=True,
add_tpos_enc_to_obj_ptrs=True,
only_obj_ptrs_in_the_past_for_eval=True,
pred_obj_scores=True,
pred_obj_scores_mlp=True,
fixed_no_obj_ptr=True,
multimask_output_for_tracking=True,
use_multimask_token_for_obj_ptr=True,
multimask_min_pt_num=0,
multimask_max_pt_num=1,
use_mlp_for_obj_ptr_proj=True,
compile_image_encoder=False,
sam_mask_decoder_extra_args=dict(
dynamic_multimask_via_stability=True,
dynamic_multimask_stability_delta=0.05,
dynamic_multimask_stability_thresh=0.98,
),
)
if checkpoint is not None:
checkpoint = attempt_download_asset(checkpoint)
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)["model"]
sam2.load_state_dict(state_dict)
sam2.eval()
return sam2
sam_model_map = {
"sam2_t.pt": build_sam2_t,
"sam2_s.pt": build_sam2_s,
"sam2_b.pt": build_sam2_b,
"sam2_l.pt": build_sam2_l,
}
def build_sam2(ckpt="sam_b.pt"):
"""Constructs a Segment Anything Model (SAM2) based on the specified checkpoint, with various size options."""
model_builder = None
ckpt = str(ckpt) # to allow Path ckpt types
for k in sam_model_map.keys():
if ckpt.endswith(k):
model_builder = sam_model_map.get(k)
if not model_builder:
raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")
return model_builder(ckpt)

@ -1,97 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
SAM2 model interface.
This module provides an interface to the Segment Anything Model (SAM2) from Ultralytics, designed for real-time image
segmentation tasks. The SAM2 model allows for promptable segmentation with unparalleled versatility in image analysis,
and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new
image distributions and tasks without prior knowledge.
Key Features:
- Promptable segmentation
- Real-time performance
- Zero-shot transfer capabilities
- Trained on SA-1B dataset
"""
from ultralytics.models.sam import SAM
from .build import build_sam2
from .predict import SAM2Predictor
class SAM2(SAM):
"""
SAM2 class for real-time image segmentation using the Segment Anything Model (SAM2).
This class extends the SAM base class, providing an interface to the SAM2 model for promptable segmentation
tasks. It supports loading pre-trained weights and offers zero-shot performance capabilities.
Attributes:
model (torch.nn.Module): The loaded SAM2 model.
task_map (Dict[str, Type[SAM2Predictor]]): Mapping of 'segment' task to SAM2Predictor.
Methods:
__init__: Initializes the SAM2 model with pre-trained weights.
_load: Loads specified weights into the SAM2 model.
Examples:
>>> sam2 = SAM2("sam2_b.pt")
>>> sam2._load('path/to/sam2_weights.pt')
>>> task_map = sam2.task_map
>>> print(task_map)
{'segment': SAM2Predictor}
Notes:
- Supports .pt and .pth file extensions for model weights.
- Offers zero-shot transfer capabilities for new image distributions and tasks.
"""
def __init__(self, model="sam2_b.pt") -> None:
"""
Initializes the SAM2 model with a pre-trained model file.
Args:
model (str): Path to the pre-trained SAM2 model file. File should have a .pt or .pth extension.
Raises:
NotImplementedError: If the model file extension is not .pt or .pth.
Examples:
>>> sam2 = SAM2("sam2_b.pt")
"""
super().__init__(model=model)
def _load(self, weights: str, task=None):
"""
Loads the specified weights into the SAM2 model.
This method is responsible for loading pre-trained weights into the SAM2 model. It supports loading
weights from files with .pt or .pth extensions.
Args:
weights (str): Path to the weights file. Should be a file with .pt or .pth extension.
task (str | None): Task name. If provided, it may be used to configure model-specific settings.
Examples:
>>> sam2_model = SAM2()
>>> sam2_model._load('path/to/sam2_weights.pt')
"""
self.model = build_sam2(weights)
@property
def task_map(self):
"""
Provides a mapping from the 'segment' task to its corresponding 'Predictor'.
Returns:
(Dict[str, Type[SAM2Predictor]]): A dictionary mapping the 'segment' task to its corresponding
SAM2Predictor class.
Examples:
>>> sam2 = SAM2()
>>> task_map = sam2.task_map
>>> print(task_map)
{'segment': SAM2Predictor}
"""
return {"segment": {"predictor": SAM2Predictor}}

@ -1 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

@ -1,305 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from typing import List, Optional, Tuple, Type
import torch
from torch import nn
from ultralytics.nn.modules import MLP, LayerNorm2d
class MaskDecoder(nn.Module):
"""Transformer-based decoder predicting instance segmentation masks from image and prompt embeddings."""
def __init__(
self,
transformer_dim: int,
transformer: nn.Module,
num_multimask_outputs: int = 3,
activation: Type[nn.Module] = nn.GELU,
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
use_high_res_features: bool = False,
iou_prediction_use_sigmoid=False,
dynamic_multimask_via_stability=False,
dynamic_multimask_stability_delta=0.05,
dynamic_multimask_stability_thresh=0.98,
pred_obj_scores: bool = False,
pred_obj_scores_mlp: bool = False,
use_multimask_token_for_obj_ptr: bool = False,
) -> None:
"""
Initializes the MaskDecoder module for predicting instance segmentation masks.
Args:
transformer_dim (int): Channel dimension of the transformer.
transformer (nn.Module): Transformer used to predict masks.
num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
activation (Type[nn.Module]): Type of activation to use when upscaling masks.
iou_head_depth (int): Depth of the MLP used to predict mask quality.
iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
use_high_res_features (bool): Whether to use high-resolution features.
iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction.
dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
pred_obj_scores (bool): Whether to predict object scores.
pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction.
use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
Attributes:
transformer_dim (int): Channel dimension of the transformer.
transformer (nn.Module): Transformer used to predict masks.
num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
iou_token (nn.Embedding): Embedding for IOU token.
num_mask_tokens (int): Total number of mask tokens.
mask_tokens (nn.Embedding): Embedding for mask tokens.
pred_obj_scores (bool): Whether to predict object scores.
obj_score_token (nn.Embedding): Embedding for object score token.
use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
output_upscaling (nn.Sequential): Upscaling layers for output.
use_high_res_features (bool): Whether to use high-resolution features.
conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0).
conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1).
output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks.
iou_prediction_head (MLP): MLP for IOU prediction.
pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction.
dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
"""
super().__init__()
self.transformer_dim = transformer_dim
self.transformer = transformer
self.num_multimask_outputs = num_multimask_outputs
self.iou_token = nn.Embedding(1, transformer_dim)
self.num_mask_tokens = num_multimask_outputs + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
self.pred_obj_scores = pred_obj_scores
if self.pred_obj_scores:
self.obj_score_token = nn.Embedding(1, transformer_dim)
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
activation(),
)
self.use_high_res_features = use_high_res_features
if use_high_res_features:
self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1)
self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1)
self.output_hypernetworks_mlps = nn.ModuleList(
[MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
)
self.iou_prediction_head = MLP(
transformer_dim,
iou_head_hidden_dim,
self.num_mask_tokens,
iou_head_depth,
sigmoid=iou_prediction_use_sigmoid,
)
if self.pred_obj_scores:
self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
if pred_obj_scores_mlp:
self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
# When outputting a single mask, optionally we can dynamically fall back to the best
# multimask output token if the single mask output token gives low stability scores.
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
repeat_image: bool,
high_res_features: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predicts masks given image and prompt embeddings.
Args:
image_embeddings (torch.Tensor): Embeddings from the image encoder.
image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings.
sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes.
dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs.
multimask_output (bool): Whether to return multiple masks or a single mask.
repeat_image (bool): Flag to repeat the image embeddings.
high_res_features (List[torch.Tensor] | None): Optional high-resolution features.
Returns:
(Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
- masks (torch.Tensor): Batched predicted masks.
- iou_pred (torch.Tensor): Batched predictions of mask quality.
- sam_tokens_out (torch.Tensor): Batched SAM token for mask output.
Examples:
>>> image_embeddings = torch.rand(1, 256, 64, 64)
>>> image_pe = torch.rand(1, 256, 64, 64)
>>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
>>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
>>> decoder = MaskDecoder(256, transformer)
>>> masks, iou_pred, sam_tokens_out = decoder.forward(image_embeddings, image_pe,
... sparse_prompt_embeddings, dense_prompt_embeddings, True, False)
"""
masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
repeat_image=repeat_image,
high_res_features=high_res_features,
)
# Select the correct mask or masks for output
if multimask_output:
masks = masks[:, 1:, :, :]
iou_pred = iou_pred[:, 1:]
elif self.dynamic_multimask_via_stability and not self.training:
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
else:
masks = masks[:, 0:1, :, :]
iou_pred = iou_pred[:, 0:1]
if multimask_output and self.use_multimask_token_for_obj_ptr:
sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
else:
# Take the mask output token. Here we *always* use the token for single mask output.
# At test time, even if we track after 1-click (and using multimask_output=True),
# we still take the single mask token here. The rationale is that we always track
# after multiple clicks during training, so the past tokens seen during training
# are always the single mask token (and we'll let it be the object-memory token).
sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
# Prepare output
return masks, iou_pred, sam_tokens_out, object_score_logits
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
repeat_image: bool,
high_res_features: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts instance segmentation masks from image and prompt embeddings using a transformer architecture."""
# Concatenate output tokens
s = 0
if self.pred_obj_scores:
output_tokens = torch.cat(
[
self.obj_score_token.weight,
self.iou_token.weight,
self.mask_tokens.weight,
],
dim=0,
)
s = 1
else:
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
if repeat_image:
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
else:
assert image_embeddings.shape[0] == tokens.shape[0]
src = image_embeddings
src = src + dense_prompt_embeddings
assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, s, :]
mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
if not self.use_high_res_features:
upscaled_embedding = self.output_upscaling(src)
else:
dc1, ln1, act1, dc2, act2 = self.output_upscaling
feat_s0, feat_s1 = high_res_features
upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
if self.pred_obj_scores:
assert s == 1
object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
else:
# Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
return masks, iou_pred, mask_tokens_out, object_score_logits
def _get_stability_scores(self, mask_logits):
"""Computes mask stability scores based on IoU between upper and lower thresholds."""
mask_logits = mask_logits.flatten(-2)
stability_delta = self.dynamic_multimask_stability_delta
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
return stability_scores
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
"""
Dynamically selects the most stable mask output based on stability scores and IoU predictions.
When outputting a single mask, if the stability score from the current single-mask output (based on output token
0) falls below a threshold, we instead select from multi-mask outputs (based on output token 1~3) the mask with
the highest predicted IoU score.
This is intended to ensure a valid mask for both clicking and tracking.
"""
# The best mask from multimask output tokens (1~3)
multimask_logits = all_mask_logits[:, 1:, :, :]
multimask_iou_scores = all_iou_scores[:, 1:]
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
best_multimask_logits = best_multimask_logits.unsqueeze(1)
best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
# The mask from singlemask output token 0 and its stability score
singlemask_logits = all_mask_logits[:, 0:1, :, :]
singlemask_iou_scores = all_iou_scores[:, 0:1]
stability_scores = self._get_stability_scores(singlemask_logits)
is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
# Dynamically fall back to best multimask output upon low stability scores.
mask_logits_out = torch.where(
is_stable[..., None, None].expand_as(singlemask_logits),
singlemask_logits,
best_multimask_logits,
)
iou_scores_out = torch.where(
is_stable.expand_as(singlemask_iou_scores),
singlemask_iou_scores,
best_multimask_iou_scores,
)
return mask_logits_out, iou_scores_out

@ -1,332 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.models.sam.modules.encoders import PatchEmbed
from .sam2_blocks import CXBlock, Fuser, MaskDownSampler, MultiScaleBlock, PositionEmbeddingSine
class MemoryEncoder(nn.Module):
"""Encodes pixel features and masks into a memory representation for efficient image segmentation."""
def __init__(
self,
out_dim,
in_dim=256, # in_dim of pix_feats
):
"""Initializes the MemoryEncoder module for encoding pixel features and masks in SAM-like models."""
super().__init__()
self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)
self.out_proj = nn.Identity()
if out_dim != in_dim:
self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
def forward(
self,
pix_feat: torch.Tensor,
masks: torch.Tensor,
skip_mask_sigmoid: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Processes pixel features and masks, fusing them to generate encoded memory representations."""
if not skip_mask_sigmoid:
masks = F.sigmoid(masks)
masks = self.mask_downsampler(masks)
# Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA
pix_feat = pix_feat.to(masks.device)
x = self.pix_feat_proj(pix_feat)
x = x + masks
x = self.fuser(x)
x = self.out_proj(x)
pos = self.position_encoding(x).to(x.dtype)
return {"vision_features": x, "vision_pos_enc": [pos]}
class ImageEncoder(nn.Module):
"""Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings."""
def __init__(
self,
trunk: nn.Module,
neck: nn.Module,
scalp: int = 0,
):
"""Initializes an image encoder with a trunk, neck, and optional scalp for feature extraction."""
super().__init__()
self.trunk = trunk
self.neck = neck
self.scalp = scalp
assert (
self.trunk.channel_list == self.neck.backbone_channel_list
), f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
def forward(self, sample: torch.Tensor):
"""Processes image input through trunk and neck, returning features, positional encodings, and FPN outputs."""
features, pos = self.neck(self.trunk(sample))
if self.scalp > 0:
# Discard the lowest resolution features
features, pos = features[: -self.scalp], pos[: -self.scalp]
src = features[-1]
output = {
"vision_features": src,
"vision_pos_enc": pos,
"backbone_fpn": features,
}
return output
class FpnNeck(nn.Module):
"""Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models."""
def __init__(
self,
d_model: int,
backbone_channel_list: List[int],
kernel_size: int = 1,
stride: int = 1,
padding: int = 0,
fpn_interp_model: str = "bilinear",
fuse_type: str = "sum",
fpn_top_down_levels: Optional[List[int]] = None,
):
"""
Initializes a modified Feature Pyramid Network (FPN) neck.
This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
similar to ViT positional embedding interpolation.
Args:
d_model (int): Dimension of the model.
backbone_channel_list (List[int]): List of channel dimensions from the backbone.
kernel_size (int): Kernel size for the convolutional layers.
stride (int): Stride for the convolutional layers.
padding (int): Padding for the convolutional layers.
fpn_interp_model (str): Interpolation mode for FPN feature resizing.
fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs.
Attributes:
position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding.
convs (nn.ModuleList): List of convolutional layers for each backbone level.
backbone_channel_list (List[int]): List of channel dimensions from the backbone.
fpn_interp_model (str): Interpolation mode for FPN feature resizing.
fuse_type (str): Type of feature fusion.
fpn_top_down_levels (List[int]): Levels with top-down feature propagation.
Examples:
>>> backbone_channels = [64, 128, 256, 512]
>>> fpn_neck = FpnNeck(256, backbone_channels)
>>> print(fpn_neck)
"""
super().__init__()
self.position_encoding = PositionEmbeddingSine(num_pos_feats=256)
self.convs = nn.ModuleList()
self.backbone_channel_list = backbone_channel_list
for dim in backbone_channel_list:
current = nn.Sequential()
current.add_module(
"conv",
nn.Conv2d(
in_channels=dim,
out_channels=d_model,
kernel_size=kernel_size,
stride=stride,
padding=padding,
),
)
self.convs.append(current)
self.fpn_interp_model = fpn_interp_model
assert fuse_type in ["sum", "avg"]
self.fuse_type = fuse_type
# levels to have top-down features in its outputs
# e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
# have top-down propagation, while outputs of level 0 and level 1 have only
# lateral features from the same backbone level.
if fpn_top_down_levels is None:
# default is to have top-down features on all levels
fpn_top_down_levels = range(len(self.convs))
self.fpn_top_down_levels = list(fpn_top_down_levels)
def forward(self, xs: List[torch.Tensor]):
"""
Performs forward pass through the Feature Pyramid Network (FPN) neck.
Args:
xs (List[torch.Tensor]): List of input tensors from the backbone, with shape (B, C, H, W) for each tensor.
Returns:
(Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing two lists:
- out: List of output feature maps after FPN processing, with shape (B, d_model, H, W) for each tensor.
- pos: List of positional encodings corresponding to each output feature map.
Examples:
>>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])
>>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]]
>>> outputs, positions = fpn_neck(inputs)
"""
out = [None] * len(self.convs)
pos = [None] * len(self.convs)
assert len(xs) == len(self.convs)
# fpn forward pass
# see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
prev_features = None
# forward in top-down order (from low to high resolution)
n = len(self.convs) - 1
for i in range(n, -1, -1):
x = xs[i]
lateral_features = self.convs[n - i](x)
if i in self.fpn_top_down_levels and prev_features is not None:
top_down_features = F.interpolate(
prev_features.to(dtype=torch.float32),
scale_factor=2.0,
mode=self.fpn_interp_model,
align_corners=(None if self.fpn_interp_model == "nearest" else False),
antialias=False,
)
prev_features = lateral_features + top_down_features
if self.fuse_type == "avg":
prev_features /= 2
else:
prev_features = lateral_features
x_out = prev_features
out[i] = x_out
pos[i] = self.position_encoding(x_out).to(x_out.dtype)
return out, pos
class Hiera(nn.Module):
"""Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks."""
def __init__(
self,
embed_dim: int = 96, # initial embed dim
num_heads: int = 1, # initial number of heads
drop_path_rate: float = 0.0, # stochastic depth
q_pool: int = 3, # number of q_pool stages
q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
dim_mul: float = 2.0, # dim_mul factor at stage shift
head_mul: float = 2.0, # head_mul factor at stage shift
window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
# window size per stage, when not using global att.
window_spec: Tuple[int, ...] = (
8,
4,
14,
7,
),
# global attn in these blocks
global_att_blocks: Tuple[int, ...] = (
12,
16,
20,
),
return_interm_layers=True, # return feats from every stage
):
"""Initializes a Hiera model with configurable architecture for hierarchical vision transformers."""
super().__init__()
assert len(stages) == len(window_spec)
self.window_spec = window_spec
depth = sum(stages)
self.q_stride = q_stride
self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
assert 0 <= q_pool <= len(self.stage_ends[:-1])
self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
self.return_interm_layers = return_interm_layers
self.patch_embed = PatchEmbed(
embed_dim=embed_dim,
kernel_size=(7, 7),
stride=(4, 4),
padding=(3, 3),
)
# Which blocks have global att?
self.global_att_blocks = global_att_blocks
# Windowed positional embedding (https://arxiv.org/abs/2311.05613)
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))
self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
cur_stage = 1
self.blocks = nn.ModuleList()
for i in range(depth):
dim_out = embed_dim
# lags by a block, so first block of
# next stage uses an initial window size
# of previous stage and final window size of current stage
window_size = self.window_spec[cur_stage - 1]
if self.global_att_blocks is not None:
window_size = 0 if i in self.global_att_blocks else window_size
if i - 1 in self.stage_ends:
dim_out = int(embed_dim * dim_mul)
num_heads = int(num_heads * head_mul)
cur_stage += 1
block = MultiScaleBlock(
dim=embed_dim,
dim_out=dim_out,
num_heads=num_heads,
drop_path=dpr[i],
q_stride=self.q_stride if i in self.q_pool_blocks else None,
window_size=window_size,
)
embed_dim = dim_out
self.blocks.append(block)
self.channel_list = (
[self.blocks[i].dim_out for i in self.stage_ends[::-1]]
if return_interm_layers
else [self.blocks[-1].dim_out]
)
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
"""Generate positional embeddings by interpolating and combining window and background embeddings."""
h, w = hw
window_embed = self.pos_embed_window
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
pos_embed = pos_embed.permute(0, 2, 3, 1)
return pos_embed
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
"""Performs hierarchical vision transformer forward pass, returning multiscale feature maps."""
x = self.patch_embed(x)
# x: (B, H, W, C)
# Add pos embed
x = x + self._get_pos_embed(x.shape[1:3])
outputs = []
for i, blk in enumerate(self.blocks):
x = blk(x)
if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):
feats = x.permute(0, 3, 1, 2)
outputs.append(feats)
return outputs

@ -1,804 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
import torch.distributed
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from ultralytics.models.sam.modules.encoders import PromptEncoder
from ultralytics.nn.modules import MLP
from .decoders import MaskDecoder
from .sam2_blocks import TwoWayTransformer
from .utils import get_1d_sine_pe, select_closest_cond_frames
# a large negative value as a placeholder score for missing objects
NO_OBJ_SCORE = -1024.0
class SAM2Model(torch.nn.Module):
"""SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities."""
mask_threshold: float = 0.0
def __init__(
self,
image_encoder,
memory_attention,
memory_encoder,
num_maskmem=7, # default 1 input frame + 6 previous frames
image_size=512,
backbone_stride=16, # stride of the image backbone output
sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob
sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob
# During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks
binarize_mask_from_pts_for_mem_enc=False,
use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder
# The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit,
# we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model
# a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM.
max_cond_frames_in_attn=-1,
# on the first frame, whether to directly add the no-memory embedding to the image feature
# (instead of using the transformer encoder)
directly_add_no_mem_embed=False,
# whether to use high-resolution feature maps in the SAM mask decoder
use_high_res_features_in_sam=False,
# whether to output multiple (3) masks for the first click on initial conditioning frames
multimask_output_in_sam=False,
# the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`;
# default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points)
multimask_min_pt_num=1,
multimask_max_pt_num=1,
# whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`)
multimask_output_for_tracking=False,
# Whether to use multimask tokens for obj ptr; Only relevant when both
# use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True
use_multimask_token_for_obj_ptr: bool = False,
# whether to use sigmoid to restrict ious prediction to [0-1]
iou_prediction_use_sigmoid=False,
# The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5).
# For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
# (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
memory_temporal_stride_for_eval=1,
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
add_all_frames_to_correct_as_cond=False,
# whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
non_overlap_masks_for_mem_enc=False,
# whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder=False,
# the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`)
max_obj_ptrs_in_encoder=16,
# whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`)
add_tpos_enc_to_obj_ptrs=True,
# whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
# with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
proj_tpos_enc_in_obj_ptrs=False,
# whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
# (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking)
only_obj_ptrs_in_the_past_for_eval=False,
# Whether to predict if there is an object in the frame
pred_obj_scores: bool = False,
# Whether to use an MLP to predict object scores
pred_obj_scores_mlp: bool = False,
# Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True;
# Whether to have a fixed no obj pointer when there is no object present
# or to use it as an additive embedding with obj_ptr produced by decoder
fixed_no_obj_ptr: bool = False,
# Soft no object, i.e. mix in no_obj_ptr softly,
# hope to make recovery easier if there is a mistake and mitigate accumulation of errors
soft_no_obj_ptr: bool = False,
use_mlp_for_obj_ptr_proj: bool = False,
# extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
sam_mask_decoder_extra_args=None,
compile_image_encoder: bool = False,
):
"""Initializes SAM2Model model with image encoder, memory attention, and memory encoder components."""
super().__init__()
# Part 1: the image backbone
self.image_encoder = image_encoder
# Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
self.use_high_res_features_in_sam = use_high_res_features_in_sam
self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
if use_obj_ptrs_in_encoder:
# A conv layer to downsample the mask prompt to stride 4 (the same stride as
# low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
# so that it can be fed into the SAM mask decoder to generate a pointer.
self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
if proj_tpos_enc_in_obj_ptrs:
assert add_tpos_enc_to_obj_ptrs # these options need to be used together
self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
# Part 2: memory attention to condition current frame's visual features
# with memories (and obj ptrs) from past frames
self.memory_attention = memory_attention
self.hidden_dim = memory_attention.d_model
# Part 3: memory encoder for the previous frame's outputs
self.memory_encoder = memory_encoder
self.mem_dim = self.hidden_dim
if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
# if there is compression of memories along channel dim
self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
self.num_maskmem = num_maskmem # Number of memories accessible
# Temporal encoding of the memories
self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
trunc_normal_(self.maskmem_tpos_enc, std=0.02)
# a single token to indicate no memory embedding from previous frames
self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
trunc_normal_(self.no_mem_embed, std=0.02)
trunc_normal_(self.no_mem_pos_enc, std=0.02)
self.directly_add_no_mem_embed = directly_add_no_mem_embed
# Apply sigmoid to the output raw mask logits (to turn them from
# range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
# On frames with mask input, whether to directly output the input mask without
# using a SAM prompt encoder + mask decoder
self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
self.multimask_output_in_sam = multimask_output_in_sam
self.multimask_min_pt_num = multimask_min_pt_num
self.multimask_max_pt_num = multimask_max_pt_num
self.multimask_output_for_tracking = multimask_output_for_tracking
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
# Part 4: SAM-style prompt encoder (for both mask and point inputs)
# and SAM-style mask decoder for the final mask output
self.image_size = image_size
self.backbone_stride = backbone_stride
self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
self.pred_obj_scores = pred_obj_scores
self.pred_obj_scores_mlp = pred_obj_scores_mlp
self.fixed_no_obj_ptr = fixed_no_obj_ptr
self.soft_no_obj_ptr = soft_no_obj_ptr
if self.fixed_no_obj_ptr:
assert self.pred_obj_scores
assert self.use_obj_ptrs_in_encoder
if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
trunc_normal_(self.no_obj_ptr, std=0.02)
self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
self._build_sam_heads()
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
self.max_cond_frames_in_attn = max_cond_frames_in_attn
# Model compilation
if compile_image_encoder:
# Compile the forward function (not the full module) to allow loading checkpoints.
print("Image encoder compilation is enabled. First forward pass will be slow.")
self.image_encoder.forward = torch.compile(
self.image_encoder.forward,
mode="max-autotune",
fullgraph=True,
dynamic=False,
)
@property
def device(self):
"""Returns the device on which the model's parameters are stored."""
return next(self.parameters()).device
def forward(self, *args, **kwargs):
"""Processes input frames and prompts to generate object masks and scores in video sequences."""
raise NotImplementedError(
"Please use the corresponding methods in SAM2VideoPredictor for inference."
"See notebooks/video_predictor_example.ipynb for an example."
)
def _build_sam_heads(self):
"""Builds SAM-style prompt encoder and mask decoder for image segmentation tasks."""
self.sam_prompt_embed_dim = self.hidden_dim
self.sam_image_embedding_size = self.image_size // self.backbone_stride
# build PromptEncoder and MaskDecoder from SAM
# (their hyperparameters like `mask_in_chans=16` are from SAM code)
self.sam_prompt_encoder = PromptEncoder(
embed_dim=self.sam_prompt_embed_dim,
image_embedding_size=(
self.sam_image_embedding_size,
self.sam_image_embedding_size,
),
input_image_size=(self.image_size, self.image_size),
mask_in_chans=16,
)
self.sam_mask_decoder = MaskDecoder(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=self.sam_prompt_embed_dim,
mlp_dim=2048,
num_heads=8,
),
transformer_dim=self.sam_prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
use_high_res_features=self.use_high_res_features_in_sam,
iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
pred_obj_scores=self.pred_obj_scores,
pred_obj_scores_mlp=self.pred_obj_scores_mlp,
use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
**(self.sam_mask_decoder_extra_args or {}),
)
if self.use_obj_ptrs_in_encoder:
# a linear projection on SAM output tokens to turn them into object pointers
self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
if self.use_mlp_for_obj_ptr_proj:
self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
else:
self.obj_ptr_proj = torch.nn.Identity()
if self.proj_tpos_enc_in_obj_ptrs:
# a linear projection on temporal positional encoding in object pointers to
# avoid potential interference with spatial positional encoding
self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
else:
self.obj_ptr_tpos_proj = torch.nn.Identity()
def _forward_sam_heads(
self,
backbone_features,
point_inputs=None,
mask_inputs=None,
high_res_features=None,
multimask_output=False,
):
"""
Forward SAM prompt encoders and mask heads.
Args:
backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts.
'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
pixel-unit coordinates in (x, y) format for P input points.
'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
0 means negative clicks, and -1 means padding.
mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
same spatial size as the image.
high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes
(B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps
for SAM decoder.
multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
output only 1 mask and its IoU estimate.
Returns:
(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
ious: Tensor of shape (B, M) with estimated IoU for each output mask.
low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask.
high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask.
obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
object_score_logits: Tensor of shape (B,) with object score logits.
Where M is 3 if multimask_output=True, and 1 if multimask_output=False.
Examples:
>>> backbone_features = torch.rand(1, 256, 32, 32)
>>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
>>> mask_inputs = torch.rand(1, 1, 512, 512)
>>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
>>> low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits = results
"""
B = backbone_features.size(0)
device = backbone_features.device
assert backbone_features.size(1) == self.sam_prompt_embed_dim
assert backbone_features.size(2) == self.sam_image_embedding_size
assert backbone_features.size(3) == self.sam_image_embedding_size
# a) Handle point prompts
if point_inputs is not None:
sam_point_coords = point_inputs["point_coords"]
sam_point_labels = point_inputs["point_labels"]
assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
else:
# If no points are provide, pad with an empty point (with label -1)
sam_point_coords = torch.zeros(B, 1, 2, device=device)
sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
# b) Handle mask prompts
if mask_inputs is not None:
# If mask_inputs is provided, downsize it into low-res mask input if needed
# and feed it as a dense mask prompt into the SAM mask encoder
assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
sam_mask_prompt = F.interpolate(
mask_inputs.float(),
size=self.sam_prompt_encoder.mask_input_size,
align_corners=False,
mode="bilinear",
antialias=True, # use antialias for downsampling
)
else:
sam_mask_prompt = mask_inputs
else:
# Otherwise, simply feed None (and SAM's prompt encoder will add
# a learned `no_mask_embed` to indicate no mask input in this case).
sam_mask_prompt = None
sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
points=(sam_point_coords, sam_point_labels),
boxes=None,
masks=sam_mask_prompt,
)
(
low_res_multimasks,
ious,
sam_output_tokens,
object_score_logits,
) = self.sam_mask_decoder(
image_embeddings=backbone_features,
image_pe=self.sam_prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
repeat_image=False, # the image is already batched
high_res_features=high_res_features,
)
if self.pred_obj_scores:
is_obj_appearing = object_score_logits > 0
# Mask used for spatial memories is always a *hard* choice between obj and no obj,
# consistent with the actual mask prediction
low_res_multimasks = torch.where(
is_obj_appearing[:, None, None],
low_res_multimasks,
NO_OBJ_SCORE,
)
# convert masks from possibly bfloat16 (or float16) to float32
# (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
low_res_multimasks = low_res_multimasks.float()
high_res_multimasks = F.interpolate(
low_res_multimasks,
size=(self.image_size, self.image_size),
mode="bilinear",
align_corners=False,
)
sam_output_token = sam_output_tokens[:, 0]
if multimask_output:
# take the best mask prediction (with the highest IoU estimation)
best_iou_inds = torch.argmax(ious, dim=-1)
batch_inds = torch.arange(B, device=device)
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
if sam_output_tokens.size(1) > 1:
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
else:
low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
# Extract object pointer from the SAM output token (with occlusion handling)
obj_ptr = self.obj_ptr_proj(sam_output_token)
if self.pred_obj_scores:
# Allow *soft* no obj ptr, unlike for masks
if self.soft_no_obj_ptr:
# Only hard possible with gt
assert not self.teacher_force_obj_scores_for_mem
lambda_is_obj_appearing = object_score_logits.sigmoid()
else:
lambda_is_obj_appearing = is_obj_appearing.float()
if self.fixed_no_obj_ptr:
obj_ptr = lambda_is_obj_appearing * obj_ptr
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
return (
low_res_multimasks,
high_res_multimasks,
ious,
low_res_masks,
high_res_masks,
obj_ptr,
object_score_logits,
)
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
"""Processes mask inputs to generate output mask logits and object pointers without using SAM."""
# Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
mask_inputs_float = mask_inputs.float()
high_res_masks = mask_inputs_float * out_scale + out_bias
low_res_masks = F.interpolate(
high_res_masks,
size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
align_corners=False,
mode="bilinear",
antialias=True, # use antialias for downsampling
)
# a dummy IoU prediction of all 1's under mask input
ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
if not self.use_obj_ptrs_in_encoder:
# all zeros as a dummy object pointer (of shape [B, C])
obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
else:
# produce an object pointer using the SAM decoder from the mask input
_, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
backbone_features=backbone_features,
mask_inputs=self.mask_downsample(mask_inputs_float),
high_res_features=high_res_features,
)
# In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
# Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
# on the object_scores from the SAM decoder.
is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
is_obj_appearing = is_obj_appearing[..., None]
lambda_is_obj_appearing = is_obj_appearing.float()
object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
if self.pred_obj_scores:
if self.fixed_no_obj_ptr:
obj_ptr = lambda_is_obj_appearing * obj_ptr
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
return (
low_res_masks,
high_res_masks,
ious,
low_res_masks,
high_res_masks,
obj_ptr,
object_score_logits,
)
def forward_image(self, img_batch: torch.Tensor):
"""Process image batch through encoder to extract multi-level features for SAM model."""
backbone_out = self.image_encoder(img_batch)
if self.use_high_res_features_in_sam:
# precompute projected level 0 and level 1 features in SAM decoder
# to avoid running it again on every SAM click
backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
return backbone_out
def _prepare_backbone_features(self, backbone_out):
"""Prepare and flatten visual features from the image backbone output."""
backbone_out = backbone_out.copy()
assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
# flatten NxCxHxW to HWxNxC
vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
def _prepare_memory_conditioned_features(
self,
frame_idx,
is_init_cond_frame,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
output_dict,
num_frames,
track_in_reverse=False, # tracking in reverse time order (for demo usage)
):
"""Prepares memory-conditioned features by fusing current frame's visual features with previous memories."""
B = current_vision_feats[-1].size(1) # batch size on this frame
C = self.hidden_dim
H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
device = current_vision_feats[-1].device
# The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
# In this case, we skip the fusion with any memory.
if self.num_maskmem == 0: # Disable memory and skip fusion
pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
return pix_feat
num_obj_ptr_tokens = 0
# Step 1: condition the visual features of the current frame on previous memories
if not is_init_cond_frame:
# Retrieve the memories encoded with the maskmem backbone
to_cat_memory, to_cat_memory_pos_embed = [], []
# Add conditioning frames's output first (all cond frames have t_pos=0 for
# when getting temporal positional embedding below)
assert len(output_dict["cond_frame_outputs"]) > 0
# Select a maximum number of temporally closest cond frames for cross attention
cond_outputs = output_dict["cond_frame_outputs"]
selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
frame_idx, cond_outputs, self.max_cond_frames_in_attn
)
t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
# Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
# the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
# We also allow taking the memory frame non-consecutively (with r>1), in which case
# we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
r = self.memory_temporal_stride_for_eval
for t_pos in range(1, self.num_maskmem):
t_rel = self.num_maskmem - t_pos # how many frames before current frame
if t_rel == 1:
# for t_rel == 1, we take the last frame (regardless of r)
if not track_in_reverse:
# the frame immediately before this frame (i.e. frame_idx - 1)
prev_frame_idx = frame_idx - t_rel
else:
# the frame immediately after this frame (i.e. frame_idx + 1)
prev_frame_idx = frame_idx + t_rel
else:
# for t_rel >= 2, we take the memory frame from every r-th frames
if not track_in_reverse:
# first find the nearest frame among every r-th frames before this frame
# for r=1, this would be (frame_idx - 2)
prev_frame_idx = ((frame_idx - 2) // r) * r
# then seek further among every r-th frames
prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
else:
# first find the nearest frame among every r-th frames after this frame
# for r=1, this would be (frame_idx + 2)
prev_frame_idx = -(-(frame_idx + 2) // r) * r
# then seek further among every r-th frames
prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
if out is None:
# If an unselected conditioning frame is among the last (self.num_maskmem - 1)
# frames, we still attend to it as if it's a non-conditioning frame.
out = unselected_cond_outputs.get(prev_frame_idx, None)
t_pos_and_prevs.append((t_pos, out))
for t_pos, prev in t_pos_and_prevs:
if prev is None:
continue # skip padding frames
# "maskmem_features" might have been offloaded to CPU in demo use cases,
# so we load it back to GPU (it's a no-op if it's already on GPU).
feats = prev["maskmem_features"].cuda(non_blocking=True)
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
# Spatial positional encoding (it might have been offloaded to CPU in eval)
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
# Temporal positional encoding
maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
to_cat_memory_pos_embed.append(maskmem_enc)
# Construct the list of past object pointers
if self.use_obj_ptrs_in_encoder:
max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
# First add those object pointers from selected conditioning frames
# (optionally, only include object pointers in the past during evaluation)
if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
ptr_cond_outputs = {
t: out
for t, out in selected_cond_outputs.items()
if (t >= frame_idx if track_in_reverse else t <= frame_idx)
}
else:
ptr_cond_outputs = selected_cond_outputs
pos_and_ptrs = [
# Temporal pos encoding contains how far away each pointer is from current frame
(abs(frame_idx - t), out["obj_ptr"])
for t, out in ptr_cond_outputs.items()
]
# Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
for t_diff in range(1, max_obj_ptrs_in_encoder):
t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
if t < 0 or (num_frames is not None and t >= num_frames):
break
out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
if out is not None:
pos_and_ptrs.append((t_diff, out["obj_ptr"]))
# If we have at least one object pointer, add them to the across attention
if len(pos_and_ptrs) > 0:
pos_list, ptrs_list = zip(*pos_and_ptrs)
# stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
obj_ptrs = torch.stack(ptrs_list, dim=0)
# a temporal positional embedding based on how far each object pointer is from
# the current frame (sine embedding normalized by the max pointer num).
if self.add_tpos_enc_to_obj_ptrs:
t_diff_max = max_obj_ptrs_in_encoder - 1
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
obj_pos = torch.tensor(pos_list, device=device)
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
else:
obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
if self.mem_dim < C:
# split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
to_cat_memory.append(obj_ptrs)
to_cat_memory_pos_embed.append(obj_pos)
num_obj_ptr_tokens = obj_ptrs.shape[0]
else:
num_obj_ptr_tokens = 0
else:
# for initial conditioning frames, encode them without using any previous memory
if self.directly_add_no_mem_embed:
# directly add no-mem embedding (instead of using the transformer encoder)
pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
return pix_feat_with_mem
# Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
# Step 2: Concatenate the memories and forward through the transformer encoder
memory = torch.cat(to_cat_memory, dim=0)
memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
pix_feat_with_mem = self.memory_attention(
curr=current_vision_feats,
curr_pos=current_vision_pos_embeds,
memory=memory,
memory_pos=memory_pos_embed,
num_obj_ptr_tokens=num_obj_ptr_tokens,
)
# reshape the output (HW)BC => BCHW
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
return pix_feat_with_mem
def _encode_new_memory(
self,
current_vision_feats,
feat_sizes,
pred_masks_high_res,
is_mask_from_pts,
):
"""Encodes the current frame's features and predicted masks into a new memory representation."""
B = current_vision_feats[-1].size(1) # batch size on this frame
C = self.hidden_dim
H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
# top-level feature, (HW)BC => BCHW
pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
if self.non_overlap_masks_for_mem_enc and not self.training:
# optionally, apply non-overlapping constraints to the masks (it's applied
# in the batch dimension and should only be used during eval, where all
# the objects come from the same video under batch size 1).
pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
# scale the raw mask logits with a temperature before applying sigmoid
binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
if binarize and not self.training:
mask_for_mem = (pred_masks_high_res > 0).float()
else:
# apply sigmoid on the raw mask logits to turn them into range (0, 1)
mask_for_mem = torch.sigmoid(pred_masks_high_res)
# apply scale and bias terms to the sigmoid probabilities
if self.sigmoid_scale_for_mem_enc != 1.0:
mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
if self.sigmoid_bias_for_mem_enc != 0.0:
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
maskmem_out = self.memory_encoder(
pix_feat,
mask_for_mem,
skip_mask_sigmoid=True, # sigmoid already applied
)
maskmem_features = maskmem_out["vision_features"]
maskmem_pos_enc = maskmem_out["vision_pos_enc"]
return maskmem_features, maskmem_pos_enc
def track_step(
self,
frame_idx,
is_init_cond_frame,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
point_inputs,
mask_inputs,
output_dict,
num_frames,
track_in_reverse=False, # tracking in reverse time order (for demo usage)
# Whether to run the memory encoder on the predicted masks. Sometimes we might want
# to skip the memory encoder with `run_mem_encoder=False`. For example,
# in demo we might call `track_step` multiple times for each user click,
# and only encode the memory when the user finalizes their clicks. And in ablation
# settings like SAM training on static images, we don't need the memory encoder.
run_mem_encoder=True,
# The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
prev_sam_mask_logits=None,
):
"""Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
if len(current_vision_feats) > 1:
high_res_features = [
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
]
else:
high_res_features = None
if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
# When use_mask_input_as_output_without_sam=True, we directly output the mask input
# (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
pix_feat = current_vision_feats[-1].permute(1, 2, 0)
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
else:
# fused the visual feature with previous memory features in the memory bank
pix_feat_with_mem = self._prepare_memory_conditioned_features(
frame_idx=frame_idx,
is_init_cond_frame=is_init_cond_frame,
current_vision_feats=current_vision_feats[-1:],
current_vision_pos_embeds=current_vision_pos_embeds[-1:],
feat_sizes=feat_sizes[-1:],
output_dict=output_dict,
num_frames=num_frames,
track_in_reverse=track_in_reverse,
)
# apply SAM-style segmentation head
# here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
# e.g. in demo where such logits come from earlier interaction instead of correction sampling
# (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
if prev_sam_mask_logits is not None:
assert point_inputs is not None and mask_inputs is None
mask_inputs = prev_sam_mask_logits
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
sam_outputs = self._forward_sam_heads(
backbone_features=pix_feat_with_mem,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
high_res_features=high_res_features,
multimask_output=multimask_output,
)
(
_,
_,
_,
low_res_masks,
high_res_masks,
obj_ptr,
_,
) = sam_outputs
current_out["pred_masks"] = low_res_masks
current_out["pred_masks_high_res"] = high_res_masks
current_out["obj_ptr"] = obj_ptr
# Finally run the memory encoder on the predicted mask to encode
# it into a new memory feature (that can be used in future frames)
if run_mem_encoder and self.num_maskmem > 0:
high_res_masks_for_mem_enc = high_res_masks
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
current_vision_feats=current_vision_feats,
feat_sizes=feat_sizes,
pred_masks_high_res=high_res_masks_for_mem_enc,
is_mask_from_pts=(point_inputs is not None),
)
current_out["maskmem_features"] = maskmem_features
current_out["maskmem_pos_enc"] = maskmem_pos_enc
else:
current_out["maskmem_features"] = None
current_out["maskmem_pos_enc"] = None
return current_out
def _use_multimask(self, is_init_cond_frame, point_inputs):
"""Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
multimask_output = (
self.multimask_output_in_sam
and (is_init_cond_frame or self.multimask_output_for_tracking)
and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
)
return multimask_output
def _apply_non_overlapping_constraints(self, pred_masks):
"""Applies non-overlapping constraints to object masks, keeping highest scoring object at each location."""
batch_size = pred_masks.size(0)
if batch_size == 1:
return pred_masks
device = pred_masks.device
# "max_obj_inds": object index of the object with the highest score at each location
max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
# "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
keep = max_obj_inds == batch_obj_inds
# suppress overlapping regions' scores below -10.0 so that the foreground regions
# don't overlap (here sigmoid(-10.0)=4.5398e-05)
pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
return pred_masks

@ -1,715 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import copy
import math
from functools import partial
from typing import Optional, Tuple, Type, Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from ultralytics.models.sam.modules.transformer import (
Attention,
)
from ultralytics.models.sam.modules.transformer import (
TwoWayAttentionBlock as SAMTwoWayAttentionBlock,
)
from ultralytics.models.sam.modules.transformer import (
TwoWayTransformer as SAMTwoWayTransformer,
)
from ultralytics.nn.modules import MLP, LayerNorm2d
from .utils import apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition
class DropPath(nn.Module):
"""Implements stochastic depth regularization for neural networks during training."""
def __init__(self, drop_prob=0.0, scale_by_keep=True):
"""Initialize DropPath module with specified drop probability and scaling option."""
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
"""Applies stochastic depth to input tensor during training, with optional scaling."""
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and self.scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
class MaskDownSampler(nn.Module):
"""Downsamples and embeds masks using convolutional layers and layer normalization for efficient processing."""
def __init__(
self,
embed_dim=256,
kernel_size=4,
stride=4,
padding=0,
total_stride=16,
activation=nn.GELU,
):
"""Initializes a mask downsampler module for progressive downsampling and channel expansion."""
super().__init__()
num_layers = int(math.log2(total_stride) // math.log2(stride))
assert stride**num_layers == total_stride
self.encoder = nn.Sequential()
mask_in_chans, mask_out_chans = 1, 1
for _ in range(num_layers):
mask_out_chans = mask_in_chans * (stride**2)
self.encoder.append(
nn.Conv2d(
mask_in_chans,
mask_out_chans,
kernel_size=kernel_size,
stride=stride,
padding=padding,
)
)
self.encoder.append(LayerNorm2d(mask_out_chans))
self.encoder.append(activation())
mask_in_chans = mask_out_chans
self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
def forward(self, x):
"""Downsamples and encodes input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
return self.encoder(x)
# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
class CXBlock(nn.Module):
"""
ConvNeXt Block for efficient feature extraction in convolutional neural networks.
This block implements a modified version of the ConvNeXt architecture, offering two equivalent
implementations for improved performance and flexibility.
Attributes:
dwconv (nn.Conv2d): Depthwise convolution layer.
norm (LayerNorm2d): Layer normalization applied to channels.
pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.
act (nn.GELU): GELU activation function.
pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.
gamma (nn.Parameter | None): Learnable scale parameter for layer scaling.
drop_path (nn.Module): DropPath layer for stochastic depth regularization.
Methods:
forward: Processes the input tensor through the ConvNeXt block.
Examples:
>>> import torch
>>> x = torch.randn(1, 64, 56, 56)
>>> block = CXBlock(dim=64, kernel_size=7, padding=3)
>>> output = block(x)
>>> print(output.shape)
torch.Size([1, 64, 56, 56])
"""
def __init__(
self,
dim,
kernel_size=7,
padding=3,
drop_path=0.0,
layer_scale_init_value=1e-6,
use_dwconv=True,
):
"""
Initialize a ConvNeXt Block.
This block implements a ConvNeXt architecture with optional depthwise convolution, layer normalization,
pointwise convolutions, and GELU activation.
Args:
dim (int): Number of input channels.
kernel_size (int): Size of the convolutional kernel. Default is 7.
padding (int): Padding size for the convolution. Default is 3.
drop_path (float): Stochastic depth rate. Default is 0.0.
layer_scale_init_value (float): Initial value for Layer Scale. Default is 1e-6.
use_dwconv (bool): Whether to use depthwise convolution. Default is True.
Attributes:
dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer.
norm (LayerNorm2d): Layer normalization applied to the output of dwconv.
pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.
act (nn.GELU): GELU activation function.
pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.
gamma (nn.Parameter | None): Learnable scale parameter for the residual path.
Examples:
>>> block = CXBlock(dim=64, kernel_size=7, padding=3)
>>> x = torch.randn(1, 64, 32, 32)
>>> output = block(x)
>>> print(output.shape)
torch.Size([1, 64, 32, 32])
"""
super().__init__()
self.dwconv = nn.Conv2d(
dim,
dim,
kernel_size=kernel_size,
padding=padding,
groups=dim if use_dwconv else 1,
) # depthwise conv
self.norm = LayerNorm2d(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(4 * dim, dim)
self.gamma = (
nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
if layer_scale_init_value > 0
else None
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, x):
"""Applies ConvNeXt block operations to input tensor, including convolutions and residual connection."""
input = x
x = self.dwconv(x)
x = self.norm(x)
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
x = input + self.drop_path(x)
return x
class Fuser(nn.Module):
"""
A module for fusing features through multiple layers of a neural network.
This class applies a series of identical layers to an input tensor, optionally projecting the input first.
Attributes:
proj (nn.Module): An optional input projection layer. Identity if no projection is needed.
layers (nn.ModuleList): A list of identical layers to be applied sequentially.
Methods:
forward: Applies the fuser to an input tensor.
Examples:
>>> layer = CXBlock(dim=256)
>>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True)
>>> x = torch.randn(1, 256, 32, 32)
>>> output = fuser(x)
>>> print(output.shape)
torch.Size([1, 256, 32, 32])
"""
def __init__(self, layer, num_layers, dim=None, input_projection=False):
"""
Initializes the Fuser module.
This module creates a sequence of identical layers and optionally applies an input projection.
Args:
layer (nn.Module): The layer to be replicated in the fuser.
num_layers (int): The number of times to replicate the layer.
dim (int | None): The dimension for input projection, if used.
input_projection (bool): Whether to use input projection.
Attributes:
proj (nn.Module): The input projection layer, or nn.Identity if not used.
layers (nn.ModuleList): A list of replicated layers.
Examples:
>>> layer = nn.Linear(64, 64)
>>> fuser = Fuser(layer, num_layers=3, dim=64, input_projection=True)
>>> input_tensor = torch.randn(1, 64)
>>> output = fuser(input_tensor)
"""
super().__init__()
self.proj = nn.Identity()
self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
if input_projection:
assert dim is not None
self.proj = nn.Conv2d(dim, dim, kernel_size=1)
def forward(self, x):
"""Applies a series of layers to the input tensor, optionally projecting it first."""
x = self.proj(x)
for layer in self.layers:
x = layer(x)
return x
class TwoWayAttentionBlock(SAMTwoWayAttentionBlock):
"""
A two-way attention block for performing self-attention and cross-attention in both directions.
This block extends the SAMTwoWayAttentionBlock and consists of four main components: self-attention on
sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and
cross-attention from dense to sparse inputs.
Attributes:
self_attn (Attention): Self-attention layer for queries.
norm1 (nn.LayerNorm): Layer normalization after the first attention block.
cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
norm2 (nn.LayerNorm): Layer normalization after the second attention block.
mlp (MLP): MLP block for transforming query embeddings.
norm3 (nn.LayerNorm): Layer normalization after the MLP block.
norm4 (nn.LayerNorm): Layer normalization after the third attention block.
cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer.
Methods:
forward: Processes input through the attention blocks and MLP.
Examples:
>>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
>>> sparse_input = torch.randn(1, 100, 256)
>>> dense_input = torch.randn(1, 256, 16, 16)
>>> sparse_output, dense_output = block(sparse_input, dense_input)
"""
def __init__(
self,
embedding_dim: int,
num_heads: int,
mlp_dim: int = 2048,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
skip_first_layer_pe: bool = False,
) -> None:
"""
Initializes a TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
This block consists of four main layers: self-attention on sparse inputs, cross-attention of sparse inputs
to dense inputs, an MLP block on sparse inputs, and cross-attention of dense inputs to sparse inputs.
Args:
embedding_dim (int): The channel dimension of the embeddings.
num_heads (int): The number of heads in the attention layers.
mlp_dim (int): The hidden dimension of the MLP block.
activation (Type[nn.Module]): The activation function of the MLP block.
attention_downsample_rate (int): The downsample rate for attention computations.
skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
Attributes:
self_attn (Attention): The self-attention layer for the queries.
norm1 (nn.LayerNorm): Layer normalization following the first attention block.
cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
norm2 (nn.LayerNorm): Layer normalization following the second attention block.
mlp (MLP): MLP block that transforms the query embeddings.
norm3 (nn.LayerNorm): Layer normalization following the MLP block.
norm4 (nn.LayerNorm): Layer normalization following the third attention block.
cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
Examples:
>>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048)
>>> sparse_inputs = torch.randn(1, 100, 256)
>>> dense_inputs = torch.randn(1, 256, 32, 32)
>>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs)
"""
super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe)
self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation)
class TwoWayTransformer(SAMTwoWayTransformer):
"""
A Two-Way Transformer module for simultaneous attention to image and query points.
This class implements a specialized transformer decoder that attends to an input image using queries with
supplied positional embeddings. It is particularly useful for tasks like object detection, image
segmentation, and point cloud processing.
Attributes:
depth (int): Number of layers in the transformer.
embedding_dim (int): Channel dimension for input embeddings.
num_heads (int): Number of heads for multihead attention.
mlp_dim (int): Internal channel dimension for the MLP block.
layers (nn.ModuleList): List of TwoWayAttentionBlock layers comprising the transformer.
final_attn_token_to_image (Attention): Final attention layer from queries to image.
norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
Methods:
forward: Processes input image embeddings and query embeddings through the transformer.
Examples:
>>> transformer = TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
>>> image_embedding = torch.randn(1, 256, 64, 64)
>>> query_embedding = torch.randn(1, 100, 256)
>>> output = transformer(image_embedding, query_embedding)
"""
def __init__(
self,
depth: int,
embedding_dim: int,
num_heads: int,
mlp_dim: int,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
) -> None:
"""
Initializes a TwoWayTransformer instance.
This transformer decoder attends to an input image using queries with supplied positional embeddings.
It is designed for tasks like object detection, image segmentation, and point cloud processing.
Args:
depth (int): Number of layers in the transformer.
embedding_dim (int): Channel dimension for the input embeddings.
num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
mlp_dim (int): Channel dimension internal to the MLP block.
activation (Type[nn.Module]): Activation function to use in the MLP block.
attention_downsample_rate (int): Downsampling rate for attention computations.
Attributes:
depth (int): Number of layers in the transformer.
embedding_dim (int): Channel dimension for the input embeddings.
num_heads (int): Number of heads for multihead attention.
mlp_dim (int): Internal channel dimension for the MLP block.
layers (nn.ModuleList): List of TwoWayAttentionBlock layers comprising the transformer.
final_attn_token_to_image (Attention): Final attention layer from queries to image.
norm_final_attn (nn.LayerNorm): Layer normalization applied to the final queries.
Examples:
>>> transformer = TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
>>> transformer
TwoWayTransformer(
(layers): ModuleList(
(0-4): 5 x TwoWayAttentionBlock(...)
)
(final_attn_token_to_image): Attention(...)
(norm_final_attn): LayerNorm(...)
)
"""
super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate)
self.layers = nn.ModuleList()
for i in range(depth):
self.layers.append(
TwoWayAttentionBlock(
embedding_dim=embedding_dim,
num_heads=num_heads,
mlp_dim=mlp_dim,
activation=activation,
attention_downsample_rate=attention_downsample_rate,
skip_first_layer_pe=(i == 0),
)
)
class RoPEAttention(Attention):
"""Implements rotary position encoding for attention mechanisms in transformer architectures."""
def __init__(
self,
*args,
rope_theta=10000.0,
# whether to repeat q rope to match k length
# this is needed for cross-attention to memories
rope_k_repeat=False,
feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
**kwargs,
):
"""Initializes RoPEAttention with rotary position encoding for attention mechanisms."""
super().__init__(*args, **kwargs)
self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)
freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
self.freqs_cis = freqs_cis
self.rope_k_repeat = rope_k_repeat
def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
"""Applies rotary position encoding and computes attention between query, key, and value tensors."""
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
# Separate into heads
q = self._separate_heads(q, self.num_heads)
k = self._separate_heads(k, self.num_heads)
v = self._separate_heads(v, self.num_heads)
# Apply rotary position encoding
w = h = math.sqrt(q.shape[-2])
self.freqs_cis = self.freqs_cis.to(q.device)
if self.freqs_cis.shape[0] != q.shape[-2]:
self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
if q.shape[-2] != k.shape[-2]:
assert self.rope_k_repeat
num_k_rope = k.size(-2) - num_k_exclude_rope
q, k[:, :, :num_k_rope] = apply_rotary_enc(
q,
k[:, :, :num_k_rope],
freqs_cis=self.freqs_cis,
repeat_freqs_k=self.rope_k_repeat,
)
# Attention
_, _, _, c_per_head = q.shape
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
attn = attn / math.sqrt(c_per_head)
attn = torch.softmax(attn, dim=-1)
# Get output
out = attn @ v
out = self._recombine_heads(out)
out = self.out_proj(out)
return out
def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
"""Applies pooling and optional normalization to a tensor, handling permutations for spatial operations."""
if pool is None:
return x
# (B, H, W, C) -> (B, C, H, W)
x = x.permute(0, 3, 1, 2)
x = pool(x)
# (B, C, H', W') -> (B, H', W', C)
x = x.permute(0, 2, 3, 1)
if norm:
x = norm(x)
return x
class MultiScaleAttention(nn.Module):
"""Implements multi-scale self-attention with optional query pooling for efficient feature extraction."""
def __init__(
self,
dim: int,
dim_out: int,
num_heads: int,
q_pool: nn.Module = None,
):
"""Initializes a multi-scale attention module with configurable query pooling and linear projections."""
super().__init__()
self.dim = dim
self.dim_out = dim_out
self.num_heads = num_heads
head_dim = dim_out // num_heads
self.scale = head_dim**-0.5
self.q_pool = q_pool
self.qkv = nn.Linear(dim, dim_out * 3)
self.proj = nn.Linear(dim_out, dim_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Applies multi-scale attention to input tensor, optionally downsampling query features."""
B, H, W, _ = x.shape
# qkv with shape (B, H * W, 3, nHead, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
# q, k, v with shape (B, H * W, nheads, C)
q, k, v = torch.unbind(qkv, 2)
# Q pooling (for downsample at stage changes)
if self.q_pool:
q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
H, W = q.shape[1:3] # downsampled shape
q = q.reshape(B, H * W, self.num_heads, -1)
# Torch's SDPA expects [B, nheads, H*W, C] so we transpose
x = F.scaled_dot_product_attention(
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
)
# Transpose back
x = x.transpose(1, 2)
x = x.reshape(B, H, W, -1)
x = self.proj(x)
return x
class MultiScaleBlock(nn.Module):
"""Multiscale attention block with window partitioning and query pooling for efficient vision transformers."""
def __init__(
self,
dim: int,
dim_out: int,
num_heads: int,
mlp_ratio: float = 4.0,
drop_path: float = 0.0,
norm_layer: Union[nn.Module, str] = "LayerNorm",
q_stride: Tuple[int, int] = None,
act_layer: nn.Module = nn.GELU,
window_size: int = 0,
):
"""Initializes a multi-scale attention block with optional window partitioning and downsampling."""
super().__init__()
if isinstance(norm_layer, str):
norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
self.dim = dim
self.dim_out = dim_out
self.norm1 = norm_layer(dim)
self.window_size = window_size
self.pool, self.q_stride = None, q_stride
if self.q_stride:
self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False)
self.attn = MultiScaleAttention(
dim,
dim_out,
num_heads=num_heads,
q_pool=self.pool,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim_out)
self.mlp = MLP(
dim_out,
int(dim_out * mlp_ratio),
dim_out,
num_layers=2,
act=act_layer,
)
if dim != dim_out:
self.proj = nn.Linear(dim, dim_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Applies multi-scale attention and MLP processing to input tensor, with optional windowing."""
shortcut = x # B, H, W, C
x = self.norm1(x)
# Skip connection
if self.dim != self.dim_out:
shortcut = do_pool(self.proj(x), self.pool)
# Window partition
window_size = self.window_size
if window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, window_size)
# Window Attention + Q Pooling (if stage change)
x = self.attn(x)
if self.q_stride:
# Shapes have changed due to Q pooling
window_size = self.window_size // self.q_stride[0]
H, W = shortcut.shape[1:3]
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
pad_hw = (H + pad_h, W + pad_w)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, window_size, pad_hw, (H, W))
x = shortcut + self.drop_path(x)
# MLP
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PositionEmbeddingSine(nn.Module):
"""Generates sinusoidal positional embeddings for 2D inputs like images."""
def __init__(
self,
num_pos_feats,
temperature: int = 10000,
normalize: bool = True,
scale: Optional[float] = None,
):
"""Initializes sinusoidal position embeddings for 2D image inputs."""
super().__init__()
assert num_pos_feats % 2 == 0, "Expecting even model width"
self.num_pos_feats = num_pos_feats // 2
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
self.cache = {}
def _encode_xy(self, x, y):
"""Encodes 2D positions using sine and cosine functions for positional embeddings."""
assert len(x) == len(y) and x.ndim == y.ndim == 1
x_embed = x * self.scale
y_embed = y * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, None] / dim_t
pos_y = y_embed[:, None] / dim_t
pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
return pos_x, pos_y
@torch.no_grad()
def encode_boxes(self, x, y, w, h):
"""Encodes box coordinates and dimensions into positional embeddings for object detection tasks."""
pos_x, pos_y = self._encode_xy(x, y)
pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
return pos
encode = encode_boxes # Backwards compatibility
@torch.no_grad()
def encode_points(self, x, y, labels):
"""Encodes 2D point coordinates with sinusoidal positional embeddings and appends labels."""
(bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
assert bx == by and nx == ny and bx == bl and nx == nl
pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
return pos
@torch.no_grad()
def forward(self, x: torch.Tensor):
"""Generate sinusoidal position embeddings for 2D inputs."""
cache_key = (x.shape[-2], x.shape[-1])
if cache_key in self.cache:
return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
y_embed = (
torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
.view(1, -1, 1)
.repeat(x.shape[0], 1, x.shape[-1])
)
x_embed = (
torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
.view(1, 1, -1)
.repeat(x.shape[0], x.shape[-2], 1)
)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
self.cache[cache_key] = pos[0]
return pos

@ -1,177 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ..sam.predict import Predictor
from .build import build_sam2
class SAM2Predictor(Predictor):
"""
A predictor class for the Segment Anything Model 2 (SAM2), extending the base Predictor class.
This class provides an interface for model inference tailored to image segmentation tasks, leveraging SAM2's
advanced architecture and promptable segmentation capabilities. It facilitates flexible and real-time mask
generation, working with various types of prompts such as bounding boxes, points, and low-resolution masks.
Attributes:
cfg (Dict): Configuration dictionary specifying model and task-related parameters.
overrides (Dict): Dictionary containing values that override the default configuration.
_callbacks (Dict): Dictionary of user-defined callback functions to augment behavior.
args (namespace): Namespace to hold command-line arguments or other operational variables.
im (torch.Tensor): Preprocessed input image tensor.
features (torch.Tensor): Extracted image features used for inference.
prompts (Dict): Collection of various prompt types, such as bounding boxes and points.
segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones.
model (torch.nn.Module): The loaded SAM2 model.
device (torch.device): The device (CPU or GPU) on which the model is loaded.
_bb_feat_sizes (List[Tuple[int, int]]): List of feature sizes for different backbone levels.
Methods:
get_model: Builds and returns the SAM2 model.
prompt_inference: Performs image segmentation inference based on various prompts.
set_image: Preprocesses and sets a single image for inference.
get_im_features: Extracts image features from the SAM2 image encoder.
Examples:
>>> predictor = SAM2Predictor(model='sam2_l.pt')
>>> predictor.set_image('path/to/image.jpg')
>>> masks, scores = predictor.prompt_inference(im=predictor.im, points=[[500, 375]], labels=[1])
>>> print(f"Generated {len(masks)} mask(s) with scores: {scores}")
"""
_bb_feat_sizes = [
(256, 256),
(128, 128),
(64, 64),
]
def get_model(self):
"""Retrieves and initializes the Segment Anything Model (SAM) for image segmentation tasks."""
return build_sam2(self.args.model)
def prompt_inference(
self,
im,
bboxes=None,
points=None,
labels=None,
masks=None,
multimask_output=False,
img_idx=-1,
):
"""
Performs image segmentation inference based on various prompts using SAM2 architecture.
Args:
im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
labels (np.ndarray | List | None): Labels for point prompts with shape (N,). 1 = foreground, 0 = background.
masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W).
multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
img_idx (int): Index of the image in the batch to process.
Returns:
(tuple): Tuple containing:
- np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
- np.ndarray: Quality scores for each mask, with length C.
- np.ndarray: Low-resolution logits with shape (C, 256, 256) for subsequent inference.
Examples:
>>> predictor = SAM2Predictor(cfg)
>>> image = torch.rand(1, 3, 640, 640)
>>> bboxes = [[100, 100, 200, 200]]
>>> masks, scores, logits = predictor.prompt_inference(image, bboxes=bboxes)
"""
features = self.get_im_features(im) if self.features is None else self.features
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
# Transform input prompts
if points is not None:
points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
points = points[None] if points.ndim == 1 else points
# Assuming labels are all positive if users don't pass labels.
if labels is None:
labels = torch.ones(points.shape[0])
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
points *= r
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
points, labels = points[:, None], labels[:, None]
if bboxes is not None:
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
bboxes = bboxes.view(-1, 2, 2) * r
bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
# NOTE: merge "boxes" and "points" into a single "points" input
# (where boxes are added at the beginning) to model.sam_prompt_encoder
if points is not None:
points = torch.cat([bboxes, points], dim=1)
labels = torch.cat([bbox_labels, labels], dim=1)
else:
points, labels = bboxes, bbox_labels
if masks is not None:
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
points = (points, labels) if points is not None else None
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
points=points,
boxes=None,
masks=masks,
)
# Predict masks
batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction
high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
image_embeddings=features["image_embed"][img_idx].unsqueeze(0),
image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
repeat_image=batched_mode,
high_res_features=high_res_features,
)
# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
# `d` could be 1 or 3 depends on `multimask_output`.
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
def set_image(self, image):
"""
Preprocesses and sets a single image for inference.
This function sets up the model if not already initialized, configures the data source to the specified image,
and preprocesses the image for feature extraction. Only one image can be set at a time.
Args:
image (str | np.ndarray): Image file path as a string, or a numpy array image read by cv2.
Raises:
AssertionError: If more than one image is set.
Examples:
>>> predictor = SAM2Predictor()
>>> predictor.set_image("path/to/image.jpg")
>>> predictor.set_image(np.array([...])) # Using a numpy array
"""
if self.model is None:
self.setup_model(model=None)
self.setup_source(image)
assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
for batch in self.dataset:
im = self.preprocess(batch[1])
self.features = self.get_im_features(im)
break
def get_im_features(self, im):
"""Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks."""
backbone_out = self.model.forward_image(im)
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
if self.model.directly_add_no_mem_embed:
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
feats = [
feat.permute(1, 2, 0).view(1, -1, *feat_size)
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
][::-1]
return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
Loading…
Cancel
Save