diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 0f183abafd..3c0c255081 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -168,6 +168,7 @@ jobs:
PERSONAL_ACCESS_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }}
INDEXNOW_KEY: ${{ secrets.INDEXNOW_KEY_DOCS }}
run: |
+ pip install black
export JUPYTER_PLATFORM_DIRS=1
python docs/build_docs.py
git clone https://github.com/ultralytics/docs.git docs-repo
diff --git a/docs/en/guides/model-monitoring-and-maintenance.md b/docs/en/guides/model-monitoring-and-maintenance.md
index 7074de6de7..a301cf16b9 100644
--- a/docs/en/guides/model-monitoring-and-maintenance.md
+++ b/docs/en/guides/model-monitoring-and-maintenance.md
@@ -135,3 +135,38 @@ Using these resources will help you solve challenges and stay up-to-date with th
## Key Takeaways
We covered key tips for monitoring, maintaining, and documenting your computer vision models. Regular updates and re-training help the model adapt to new data patterns. Detecting and fixing data drift helps your model stay accurate. Continuous monitoring catches issues early, and good documentation makes collaboration and future updates easier. Following these steps will help your computer vision project stay successful and effective over time.
+
+## FAQ
+
+### How do I monitor the performance of my deployed computer vision model?
+
+Monitoring the performance of your deployed computer vision model is crucial to ensure its accuracy and reliability over time. You can use tools like [Prometheus](https://prometheus.io/), [Grafana](https://grafana.com/), and [Evidently AI](https://www.evidentlyai.com/) to track key metrics, detect anomalies, and identify data drift. Regularly monitor inputs and outputs, set up alerts for unusual behavior, and use diverse data sources to get a comprehensive view of your model's performance. For more details, check out our section on [Model Monitoring](#model-monitoring-is-key).
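+
+As a minimal sketch (assuming the `prometheus_client` package and a hypothetical `monitored_predict` wrapper around your model), you can expose basic inference metrics for Prometheus to scrape and Grafana to chart:
+
+```python
+import time
+
+from prometheus_client import Counter, Histogram, start_http_server
+
+INFERENCES = Counter("cv_inferences_total", "Total number of inference requests")
+LATENCY = Histogram("cv_inference_latency_seconds", "Inference latency in seconds")
+
+
+def monitored_predict(model, image):
+    """Run a prediction while recording request count and latency."""
+    INFERENCES.inc()
+    start = time.perf_counter()
+    results = model(image)  # e.g. an Ultralytics YOLO model call
+    LATENCY.observe(time.perf_counter() - start)
+    return results
+
+
+if __name__ == "__main__":
+    start_http_server(8000)  # metrics served at http://localhost:8000/metrics for Prometheus to scrape
+```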
+
+### What are the best practices for maintaining computer vision models after deployment?
+
+Maintaining computer vision models involves regular updates, retraining, and monitoring to ensure continued accuracy and relevance. Best practices include:
+
+- **Continuous Monitoring**: Track performance metrics and data quality regularly.
+- **Data Drift Detection**: Use statistical techniques to identify changes in data distributions.
+- **Regular Updates and Retraining**: Implement incremental learning or periodic full retraining based on data changes.
+- **Documentation**: Maintain detailed documentation of model architecture, training processes, and evaluation metrics. For more insights, visit our [Model Maintenance](#model-maintenance) section.
+
+### Why is data drift detection important for AI models?
+
+Data drift detection is essential because it helps identify when the statistical properties of the input data change over time, which can degrade model performance. Techniques like continuous monitoring, statistical tests (e.g., Kolmogorov-Smirnov test), and feature drift analysis can help spot issues early. Addressing data drift ensures that your model remains accurate and relevant in changing environments. Learn more about data drift detection in our [Data Drift Detection](#data-drift-detection) section.
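+
+As a minimal sketch of the statistical-test approach (the feature values and significance threshold below are illustrative assumptions), you can compare a reference sample against current production data with SciPy's two-sample Kolmogorov-Smirnov test:
+
+```python
+import numpy as np
+from scipy.stats import ks_2samp
+
+# Reference (training-time) and current (production) values for one input feature,
+# e.g. mean image brightness per frame; the values here are synthetic for illustration.
+reference = np.random.normal(loc=0.5, scale=0.1, size=1000)
+current = np.random.normal(loc=0.6, scale=0.1, size=1000)
+
+statistic, p_value = ks_2samp(reference, current)
+if p_value < 0.05:  # illustrative significance threshold
+    print(f"Possible data drift detected (KS statistic={statistic:.3f}, p={p_value:.4f})")
+else:
+    print("No significant drift detected for this feature")
+```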
+
+### What tools can I use for anomaly detection in computer vision models?
+
+For anomaly detection in computer vision models, tools like [Prometheus](https://prometheus.io/), [Grafana](https://grafana.com/), and [Evidently AI](https://www.evidentlyai.com/) are highly effective. These tools can help you set up alert systems to detect unusual data points or patterns that deviate from expected behavior. Configurable alerts and standardized messages can help you respond quickly to potential issues. Explore more in our [Anomaly Detection and Alert Systems](#anomaly-detection-and-alert-systems) section.
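+
+As a simple illustration of an alert rule (the window size and z-score threshold are assumptions, and `confidences` is a hypothetical list of per-frame mean detection confidences), you can flag frames whose confidence deviates sharply from recent history:
+
+```python
+import numpy as np
+
+
+def confidence_anomaly(confidences, window=100, z_threshold=3.0):
+    """Return True when the latest mean confidence deviates strongly from recent history."""
+    history = np.asarray(confidences[-window:], dtype=float)
+    if history.size < window:
+        return False  # not enough history to judge
+    mean, std = history[:-1].mean(), history[:-1].std()
+    if std == 0:
+        return False
+    return abs(history[-1] - mean) / std > z_threshold
+```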
+
+### How can I document my computer vision project effectively?
+
+Effective documentation of a computer vision project should include:
+
+- **Project Overview**: High-level summary, problem statement, and solution approach.
+- **Model Architecture**: Details of the model structure, components, and hyperparameters.
+- **Data Preparation**: Information on data sources, preprocessing steps, and transformations.
+- **Training Process**: Description of the training procedure, datasets used, and challenges encountered.
+- **Evaluation Metrics**: Metrics used for performance evaluation and analysis.
+- **Deployment Steps**: Steps taken for model deployment and any specific challenges.
+- **Monitoring and Maintenance Procedure**: Plan for ongoing monitoring and maintenance. For more comprehensive guidelines, refer to our [Documentation](#documentation) section.
diff --git a/docs/en/guides/steps-of-a-cv-project.md b/docs/en/guides/steps-of-a-cv-project.md
index a734ecf29b..a1fbdb5e97 100644
--- a/docs/en/guides/steps-of-a-cv-project.md
+++ b/docs/en/guides/steps-of-a-cv-project.md
@@ -10,6 +10,17 @@ keywords: Computer Vision, AI, Object Detection, Image Classification, Instance
Computer vision is a subfield of artificial intelligence (AI) that helps computers see and understand the world like humans do. It processes and analyzes images or videos to extract information, recognize patterns, and make decisions based on that data.
+
+Watch: How to Do Computer Vision Projects | A Step-by-Step Guide
+
Computer vision techniques like [object detection](../tasks/detect.md), [image classification](../tasks/classify.md), and [instance segmentation](../tasks/segment.md) can be applied across various industries, from [autonomous driving](https://www.ultralytics.com/solutions/ai-in-self-driving) to [medical imaging](https://www.ultralytics.com/solutions/ai-in-healthcare) to gain valuable insights.
diff --git a/docs/en/hub/index.md b/docs/en/hub/index.md
index a88fce70ca..cc5177ca18 100644
--- a/docs/en/hub/index.md
+++ b/docs/en/hub/index.md
@@ -76,3 +76,52 @@ We hope that the resources here will help you get the most out of HUB. Please br
- [**Ultralytics HUB App**](app/index.md): Learn about the Ultralytics HUB App, which allows you to run models directly on your mobile device.
- [**iOS**](app/ios.md): Explore CoreML acceleration on iPhones and iPads.
- [**Android**](app/android.md): Explore TFLite acceleration on Android devices.
+
+## FAQ
+
+### How do I get started with Ultralytics HUB for training YOLO models?
+
+To get started with [Ultralytics HUB](https://ultralytics.com/hub), follow these steps:
+
+1. **Sign Up:** Create an account on the [Ultralytics HUB](https://ultralytics.com/hub).
+2. **Upload Dataset:** Navigate to the [Datasets](datasets.md) section to upload your custom dataset.
+3. **Train Model:** Go to the [Models](models.md) section and select a pre-trained YOLOv5 or YOLOv8 model to start training.
+4. **Deploy Model:** Once trained, preview and deploy your model using the [Ultralytics HUB App](app/index.md) for real-time tasks.
+
+For a detailed guide, refer to the [Quickstart](quickstart.md) page.
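+
+If you prefer working from Python, a minimal sketch of connecting a local environment to HUB looks like the following (`YOUR_API_KEY` and `MODEL_ID` are placeholders you obtain from your HUB account):
+
+```python
+from ultralytics import YOLO, hub
+
+hub.login("YOUR_API_KEY")  # placeholder: generate an API key in your HUB account settings
+
+# Placeholder: replace MODEL_ID with the ID of a model created in the HUB Models section
+model = YOLO("https://hub.ultralytics.com/models/MODEL_ID")
+results = model.train()  # training arguments are taken from the model's HUB configuration
+```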
+
+### What are the benefits of using Ultralytics HUB over other AI platforms?
+
+[Ultralytics HUB](https://ultralytics.com/hub) offers several unique benefits:
+
+- **User-Friendly Interface:** Intuitive design for easy dataset uploads and model training.
+- **Pre-Trained Models:** Access to a variety of pre-trained YOLOv5 and YOLOv8 models.
+- **Cloud Training:** Seamless cloud training capabilities, detailed on the [Cloud Training](cloud-training.md) page.
+- **Real-Time Deployment:** Effortlessly deploy models for real-time applications using the [Ultralytics HUB App](app/index.md).
+- **Team Collaboration:** Collaborate with your team efficiently through the [Teams](teams.md) feature.
+
+Explore more about the advantages in our [Ultralytics HUB Blog](https://www.ultralytics.com/blog/ultralytics-hub-a-game-changer-for-computer-vision).
+
+### Can I use Ultralytics HUB for object detection on mobile devices?
+
+Yes, Ultralytics HUB supports object detection on mobile devices. You can run YOLOv5 and YOLOv8 models on both iOS and Android devices using the Ultralytics HUB App. For more details:
+
+- **iOS:** Learn about CoreML acceleration on iPhones and iPads in the [iOS](app/ios.md) section.
+- **Android:** Explore TFLite acceleration on Android devices in the [Android](app/android.md) section.
+
+### How do I manage and organize my projects in Ultralytics HUB?
+
+Ultralytics HUB allows you to manage and organize your projects efficiently. You can group your models into projects for better organization. To learn more:
+
+- Visit the [Projects](projects.md) page for detailed instructions on creating, editing, and managing projects.
+- Use the [Teams](teams.md) feature to collaborate with team members and share resources.
+
+### What integrations are available with Ultralytics HUB?
+
+Ultralytics HUB offers seamless integrations with various platforms to enhance your machine learning workflows. Some key integrations include:
+
+- **Roboflow:** For dataset management and model training. Learn more on the [Integrations](integrations.md) page.
+- **Google Colab:** Efficiently train models using Google Colab's cloud-based environment. Detailed steps are available in the [Google Colab](https://docs.ultralytics.com/integrations/google-colab) section.
+- **Weights & Biases:** For enhanced experiment tracking and visualization. Explore the [Weights & Biases](https://docs.ultralytics.com/integrations/weights-biases) integration.
+
+For a complete list of integrations, refer to the [Integrations](integrations.md) page.
diff --git a/docs/en/integrations/comet.md b/docs/en/integrations/comet.md
index bc7cc05af5..7c9ab2a9c5 100644
--- a/docs/en/integrations/comet.md
+++ b/docs/en/integrations/comet.md
@@ -53,7 +53,7 @@ Then, you can initialize your Comet project. Comet will automatically detect the
```python
import comet_ml
-comet_ml.init(project_name="comet-example-yolov8-coco128")
+comet_ml.login(project_name="comet-example-yolov8-coco128")
```
If you are using a Google Colab notebook, the code above will prompt you to enter your API key for initialization.
@@ -72,7 +72,7 @@ Before diving into the usage instructions, be sure to check out the range of [YO
# Load a model
model = YOLO("yolov8n.pt")
- # train the model
+ # Train the model
results = model.train(
data="coco8.yaml",
project="comet-example-yolov8-coco128",
@@ -200,7 +200,7 @@ To integrate Comet ML with Ultralytics YOLOv8, follow these steps:
```python
import comet_ml
- comet_ml.init(project_name="comet-example-yolov8-coco128")
+ comet_ml.login(project_name="comet-example-yolov8-coco128")
```
4. **Train your YOLOv8 model and log metrics**:
@@ -210,7 +210,12 @@ To integrate Comet ML with Ultralytics YOLOv8, follow these steps:
model = YOLO("yolov8n.pt")
results = model.train(
- data="coco8.yaml", project="comet-example-yolov8-coco128", batch=32, save_period=1, save_json=True, epochs=3
+ data="coco8.yaml",
+ project="comet-example-yolov8-coco128",
+ batch=32,
+ save_period=1,
+ save_json=True,
+ epochs=3,
)
```
@@ -255,7 +260,7 @@ Comet ML allows for extensive customization of its logging behavior using enviro
os.environ["COMET_EVAL_LOG_CONFUSION_MATRIX"] = "false"
```
-For more customization options, refer to the [Customizing Comet ML Logging](#customizing-comet-ml-logging) section.
+Refer to the [Customizing Comet ML Logging](#customizing-comet-ml-logging) section for more customization options.
### How do I view detailed metrics and visualizations of my YOLOv8 training on Comet ML?
diff --git a/docs/en/integrations/ibm-watsonx.md b/docs/en/integrations/ibm-watsonx.md
index da53b9c048..e6fae564d7 100644
--- a/docs/en/integrations/ibm-watsonx.md
+++ b/docs/en/integrations/ibm-watsonx.md
@@ -321,3 +321,90 @@ We explored IBM Watsonx key features, and how to train a YOLOv8 model using IBM
For further details on usage, visit [IBM Watsonx official documentation](https://www.ibm.com/watsonx).
Also, be sure to check out the [Ultralytics integration guide page](./index.md), to learn more about different exciting integrations.
+
+## FAQ
+
+### How do I train a YOLOv8 model using IBM Watsonx?
+
+To train a YOLOv8 model using IBM Watsonx, follow these steps:
+
+1. **Set Up Your Environment**: Create an IBM Cloud account and set up a Watsonx.ai project. Use a Jupyter Notebook for your coding environment.
+2. **Install Libraries**: Install necessary libraries like `torch`, `opencv`, and `ultralytics`.
+3. **Load Data**: Use the Kaggle API to load your dataset into Watsonx.
+4. **Preprocess Data**: Organize your dataset into the required directory structure and update the `.yaml` configuration file.
+5. **Train the Model**: Use the YOLO command-line interface or Python API to train your model with parameters such as `epochs`, `batch`, and the learning rate `lr0` (a Python sketch follows below).
+6. **Test and Evaluate**: Run inference to test the model and evaluate its performance using metrics like precision and recall.
+
+For detailed instructions, refer to our [YOLOv8 Model Training guide](../modes/train.md).
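+
+For step 5 above, a minimal Python sketch looks like this (the dataset path and hyperparameter values are placeholders to adapt to your Watsonx project):
+
+```python
+from ultralytics import YOLO
+
+model = YOLO("yolov8n.pt")  # start from a pretrained checkpoint
+
+# The dataset path and hyperparameter values below are placeholders
+results = model.train(
+    data="path/to/your_dataset.yaml",
+    epochs=50,
+    batch=16,
+    imgsz=640,
+    lr0=0.01,  # initial learning rate
+)
+```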
+
+### What are the key features of IBM Watsonx for AI model training?
+
+IBM Watsonx offers several key features for AI model training:
+
+- **Watsonx.ai**: Provides tools for AI development, including access to IBM-supported custom models and third-party models like Llama 3. It includes the Prompt Lab, Tuning Studio, and Flows Engine for comprehensive AI lifecycle management.
+- **Watsonx.data**: Supports cloud and on-premises deployments, offering centralized data access, efficient query engines like Presto and Spark, and an AI-powered semantic layer.
+- **Watsonx.governance**: Automates compliance, manages risk with alerts, and provides tools for detecting issues like bias and drift. It also includes dashboards and reporting tools for collaboration.
+
+For more information, visit the [IBM Watsonx official documentation](https://www.ibm.com/watsonx).
+
+### Why should I use IBM Watsonx for training Ultralytics YOLOv8 models?
+
+IBM Watsonx is an excellent choice for training Ultralytics YOLOv8 models due to its comprehensive suite of tools that streamline the AI lifecycle. Key benefits include:
+
+- **Scalability**: Easily scale your model training with IBM Cloud services.
+- **Integration**: Seamlessly integrate with various data sources and APIs.
+- **User-Friendly Interface**: Simplifies the development process with a collaborative and intuitive interface.
+- **Advanced Tools**: Access to powerful tools like the Prompt Lab, Tuning Studio, and Flows Engine for enhancing model performance.
+
+Learn more about [Ultralytics YOLOv8](https://github.com/ultralytics/ultralytics) and how to train models using IBM Watsonx in our [integration guide](./index.md).
+
+### How can I preprocess my dataset for YOLOv8 training on IBM Watsonx?
+
+To preprocess your dataset for YOLOv8 training on IBM Watsonx:
+
+1. **Organize Directories**: Ensure your dataset follows the YOLO directory structure with separate subdirectories for images and labels within the train/val/test split.
+2. **Update .yaml File**: Modify the `.yaml` configuration file to reflect the new directory structure and class names.
+3. **Run Preprocessing Script**: Use a Python script to reorganize your dataset and update the `.yaml` file accordingly.
+
+Here's a sample script to organize your dataset:
+
+```python
+import os
+import shutil
+
+
+def organize_files(directory):
+ for subdir in ["train", "test", "val"]:
+ subdir_path = os.path.join(directory, subdir)
+ if not os.path.exists(subdir_path):
+ continue
+
+ images_dir = os.path.join(subdir_path, "images")
+ labels_dir = os.path.join(subdir_path, "labels")
+
+ os.makedirs(images_dir, exist_ok=True)
+ os.makedirs(labels_dir, exist_ok=True)
+
+ for filename in os.listdir(subdir_path):
+ if filename.endswith(".txt"):
+ shutil.move(os.path.join(subdir_path, filename), os.path.join(labels_dir, filename))
+ elif filename.endswith(".jpg") or filename.endswith(".png") or filename.endswith(".jpeg"):
+ shutil.move(os.path.join(subdir_path, filename), os.path.join(images_dir, filename))
+
+
+if __name__ == "__main__":
+    work_dir = "/path/to/your/working/directory"  # set this to the working directory that contains the dataset
+    directory = f"{work_dir}/trash_ICRA19/dataset"
+    organize_files(directory)
+```
+
+For more details, refer to our [data preprocessing guide](../guides/preprocessing_annotated_data.md).
+
+### What are the prerequisites for training a YOLOv8 model on IBM Watsonx?
+
+Before you start training a YOLOv8 model on IBM Watsonx, ensure you have the following prerequisites:
+
+- **IBM Cloud Account**: Create an account on IBM Cloud to access Watsonx.ai.
+- **Kaggle Account**: For loading datasets, you'll need a Kaggle account and an API key.
+- **Jupyter Notebook**: Set up a Jupyter Notebook environment within Watsonx.ai for coding and model training.
+
+For more information on setting up your environment, visit our [Ultralytics Installation guide](../quickstart.md).
diff --git a/docs/en/integrations/jupyterlab.md b/docs/en/integrations/jupyterlab.md
index 8e2a68029f..9a67394d63 100644
--- a/docs/en/integrations/jupyterlab.md
+++ b/docs/en/integrations/jupyterlab.md
@@ -108,3 +108,102 @@ We've explored how JupyterLab can be a powerful tool for experimenting with Ultr
For more details, visit the [JupyterLab FAQ Page](https://jupyterlab.readthedocs.io/en/stable/getting_started/faq.html).
Interested in more YOLOv8 integrations? Check out the [Ultralytics integration guide](./index.md) to explore additional tools and capabilities for your machine learning projects.
+
+## FAQ
+
+### How do I use JupyterLab to train a YOLOv8 model?
+
+To train a YOLOv8 model using JupyterLab:
+
+1. Install JupyterLab and the Ultralytics package:
+
+    ```bash
+ pip install jupyterlab ultralytics
+ ```
+
+2. Launch JupyterLab and open a new notebook.
+
+3. Import the YOLO model and load a pretrained model:
+
+ ```python
+ from ultralytics import YOLO
+
+ model = YOLO("yolov8n.pt")
+ ```
+
+4. Train the model on your custom dataset:
+
+ ```python
+ results = model.train(data="path/to/your/data.yaml", epochs=100, imgsz=640)
+ ```
+
+5. Visualize training results using JupyterLab's built-in plotting capabilities:
+
+    ```python
+    %matplotlib inline
+    from ultralytics.utils.plotting import plot_results
+
+    plot_results("runs/detect/train/results.csv")  # adjust the path to your training run's results.csv
+    ```
+
+JupyterLab's interactive environment allows you to easily modify parameters, visualize results, and iterate on your model training process.
+
+### What are the key features of JupyterLab that make it suitable for YOLOv8 projects?
+
+JupyterLab offers several features that make it ideal for YOLOv8 projects:
+
+1. Interactive code execution: Test and debug YOLOv8 code snippets in real-time.
+2. Integrated file browser: Easily manage datasets, model weights, and configuration files.
+3. Flexible layout: Arrange multiple notebooks, terminals, and output windows side-by-side for efficient workflow.
+4. Rich output display: Visualize YOLOv8 detection results, training curves, and model performance metrics inline.
+5. Markdown support: Document your YOLOv8 experiments and findings with rich text and images.
+6. Extension ecosystem: Enhance functionality with extensions for version control, [remote computing](google-colab.md), and more.
+
+These features allow for a seamless development experience when working with YOLOv8 models, from data preparation to model deployment.
+
+### How can I optimize YOLOv8 model performance using JupyterLab?
+
+To optimize YOLOv8 model performance in JupyterLab:
+
+1. Use the autobatch feature to determine the optimal batch size:
+
+ ```python
+ from ultralytics.utils.autobatch import autobatch
+
+ optimal_batch_size = autobatch(model)
+ ```
+
+2. Implement [hyperparameter tuning](../guides/hyperparameter-tuning.md) using libraries like Ray Tune:
+
+ ```python
+ from ultralytics.utils.tuner import run_ray_tune
+
+ best_results = run_ray_tune(model, data="path/to/data.yaml")
+ ```
+
+3. Visualize and analyze model metrics using JupyterLab's plotting capabilities:
+
+ ```python
+ from ultralytics.utils.plotting import plot_results
+
+    plot_results("runs/detect/train/results.csv")  # plot metrics from the training run's results.csv; adjust the path to your run
+ ```
+
+4. Experiment with different model architectures and [export formats](../modes/export.md) to find the best balance of speed and accuracy for your specific use case.
+
+JupyterLab's interactive environment allows for quick iterations and real-time feedback, making it easier to optimize your YOLOv8 models efficiently.
+
+### How do I handle common issues when working with JupyterLab and YOLOv8?
+
+When working with JupyterLab and YOLOv8, you might encounter some common issues. Here's how to handle them:
+
+1. GPU memory issues:
+
+    - Use `torch.cuda.empty_cache()` to clear GPU memory between runs (a cleanup sketch follows this list).
+ - Adjust batch size or image size to fit your GPU memory.
+
+2. Package conflicts:
+
+ - Create a separate conda environment for your YOLOv8 projects to avoid conflicts.
+ - Use `!pip install package_name` in a notebook cell to install missing packages.
+
+3. Kernel crashes:
+ - Restart the kernel and run cells one by one to identify the problematic code.
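+
+For the GPU memory case, a small cleanup sketch (assuming `model` is the YOLO object from an earlier cell) can free cached memory between runs:
+
+```python
+import gc
+
+import torch
+
+del model  # assumes `model` was created in a previous cell
+gc.collect()
+torch.cuda.empty_cache()  # release cached GPU memory back to the driver
+```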
diff --git a/docs/en/integrations/kaggle.md b/docs/en/integrations/kaggle.md
index 12b7468158..3ac789cd59 100644
--- a/docs/en/integrations/kaggle.md
+++ b/docs/en/integrations/kaggle.md
@@ -90,3 +90,50 @@ We've seen how Kaggle can boost your YOLOv8 projects by providing free access to
For more details, visit [Kaggle's documentation](https://www.kaggle.com/docs).
Interested in more YOLOv8 integrations? Check out the[ Ultralytics integration guide](https://docs.ultralytics.com/integrations/) to explore additional tools and capabilities for your machine learning projects.
+
+## FAQ
+
+### How do I train a YOLOv8 model on Kaggle?
+
+Training a YOLOv8 model on Kaggle is straightforward. First, access the [Kaggle YOLOv8 Notebook](https://www.kaggle.com/ultralytics/yolov8). Sign in to your Kaggle account, copy and edit the notebook, and select a GPU under the accelerator settings. Run the notebook cells to start training. For more detailed steps, refer to our [YOLOv8 Model Training guide](../modes/train.md).
+
+### What are the benefits of using Kaggle for YOLOv8 model training?
+
+Kaggle offers several advantages for training YOLOv8 models:
+
+- **Free GPU Access**: Utilize powerful GPUs like Nvidia Tesla P100 or T4 x2 for up to 30 hours per week.
+- **Pre-installed Libraries**: Libraries like TensorFlow and PyTorch are pre-installed, simplifying the setup.
+- **Community Collaboration**: Engage with a vast community of data scientists and machine learning enthusiasts.
+- **Version Control**: Easily manage different versions of your notebooks and revert to previous versions if needed.
+
+For more details, visit our [Ultralytics integration guide](https://docs.ultralytics.com/integrations/).
+
+### What common issues might I encounter when using Kaggle for YOLOv8, and how can I resolve them?
+
+Common issues include:
+
+- **Access to GPUs**: Ensure you activate a GPU in your notebook settings. Kaggle allows up to 30 hours of GPU usage per week.
+- **Dataset Licenses**: Check the license of each dataset to understand usage restrictions.
+- **Saving and Committing Notebooks**: Click "Save Version" to save your notebook's state and access output files from the Output tab.
+- **Collaboration**: Kaggle supports asynchronous collaboration; multiple users cannot edit a notebook simultaneously.
+
+For more troubleshooting tips, see our [Common Issues guide](../guides/yolo-common-issues.md).
+
+### Why should I choose Kaggle over other platforms like Google Colab for training YOLOv8 models?
+
+Kaggle offers unique features that make it an excellent choice:
+
+- **Public Notebooks**: Share your work with the community for feedback and collaboration.
+- **Free Access to TPUs**: Speed up training with powerful TPUs without extra costs.
+- **Comprehensive History**: Track changes over time with a detailed history of notebook commits.
+- **Resource Availability**: Significant resources are provided for each notebook session, including 12 hours of execution time for CPU and GPU sessions.
+
+For a comparison with Google Colab, refer to our [Google Colab guide](./google-colab.md).
+
+### How can I revert to a previous version of my Kaggle notebook?
+
+To revert to a previous version:
+
+1. Open the notebook and click on the three vertical dots in the top right corner.
+2. Select "View Versions."
+3. Find the version you want to revert to, click on the "..." menu next to it, and select "Revert to Version."
+4. Click "Save Version" to commit the changes.
diff --git a/docs/en/models/index.md b/docs/en/models/index.md
index d848e78920..0b294801d5 100644
--- a/docs/en/models/index.md
+++ b/docs/en/models/index.md
@@ -20,12 +20,13 @@ Here are some of the key models supported:
6. **[YOLOv8](yolov8.md) NEW 🚀**: The latest version of the YOLO family, featuring enhanced capabilities such as instance segmentation, pose/keypoints estimation, and classification.
7. **[YOLOv9](yolov9.md)**: An experimental model trained on the Ultralytics [YOLOv5](yolov5.md) codebase implementing Programmable Gradient Information (PGI).
8. **[YOLOv10](yolov10.md)**: By Tsinghua University, featuring NMS-free training and efficiency-accuracy driven architecture, delivering state-of-the-art performance and latency.
-9. **[Segment Anything Model (SAM)](sam.md)**: Meta's Segment Anything Model (SAM).
-10. **[Mobile Segment Anything Model (MobileSAM)](mobile-sam.md)**: MobileSAM for mobile applications, by Kyung Hee University.
-11. **[Fast Segment Anything Model (FastSAM)](fast-sam.md)**: FastSAM by Image & Video Analysis Group, Institute of Automation, Chinese Academy of Sciences.
-12. **[YOLO-NAS](yolo-nas.md)**: YOLO Neural Architecture Search (NAS) Models.
-13. **[Realtime Detection Transformers (RT-DETR)](rtdetr.md)**: Baidu's PaddlePaddle Realtime Detection Transformer (RT-DETR) models.
-14. **[YOLO-World](yolo-world.md)**: Real-time Open Vocabulary Object Detection models from Tencent AI Lab.
+9. **[Segment Anything Model (SAM)](sam.md)**: Meta's original Segment Anything Model (SAM).
+10. **[Segment Anything Model 2 (SAM2)](sam-2.md)**: The next generation of Meta's Segment Anything Model (SAM) for videos and images.
+11. **[Mobile Segment Anything Model (MobileSAM)](mobile-sam.md)**: MobileSAM for mobile applications, by Kyung Hee University.
+12. **[Fast Segment Anything Model (FastSAM)](fast-sam.md)**: FastSAM by Image & Video Analysis Group, Institute of Automation, Chinese Academy of Sciences.
+13. **[YOLO-NAS](yolo-nas.md)**: YOLO Neural Architecture Search (NAS) Models.
+14. **[Realtime Detection Transformers (RT-DETR)](rtdetr.md)**: Baidu's PaddlePaddle Realtime Detection Transformer (RT-DETR) models.
+15. **[YOLO-World](yolo-world.md)**: Real-time Open Vocabulary Object Detection models from Tencent AI Lab.
diff --git a/docs/en/models/sam-2.md b/docs/en/models/sam-2.md
new file mode 100644
index 0000000000..001bd89b60
--- /dev/null
+++ b/docs/en/models/sam-2.md
@@ -0,0 +1,349 @@
+---
+comments: true
+description: Discover SAM 2, the next generation of Meta's Segment Anything Model, supporting real-time promptable segmentation in both images and videos with state-of-the-art performance. Learn about its key features, datasets, and how to use it.
+keywords: SAM 2, Segment Anything, video segmentation, image segmentation, promptable segmentation, zero-shot performance, SA-V dataset, Ultralytics, real-time segmentation, AI, machine learning
+---
+
+# SAM 2: Segment Anything Model 2
+
+!!! Note "🚧 SAM 2 Integration In Progress 🚧"
+
+ The SAM 2 features described in this documentation are currently not enabled in the `ultralytics` package. The Ultralytics team is actively working on integrating SAM 2, and these capabilities should be available soon. We appreciate your patience as we work to implement this exciting new model.
+
+SAM 2, the successor to Meta's [Segment Anything Model (SAM)](sam.md), is a cutting-edge tool designed for comprehensive object segmentation in both images and videos. It excels in handling complex visual data through a unified, promptable model architecture that supports real-time processing and zero-shot generalization.
+
+![SAM 2 Example Results](https://github.com/facebookresearch/segment-anything-2/raw/main/assets/sa_v_dataset.jpg?raw=true)
+
+## Key Features
+
+### Unified Model Architecture
+
+SAM 2 combines the capabilities of image and video segmentation in a single model. This unification simplifies deployment and allows for consistent performance across different media types. It leverages a flexible prompt-based interface, enabling users to specify objects of interest through various prompt types, such as points, bounding boxes, or masks.
+
+### Real-Time Performance
+
+The model achieves real-time inference speeds, processing approximately 44 frames per second. This makes SAM 2 suitable for applications requiring immediate feedback, such as video editing and augmented reality.
+
+### Zero-Shot Generalization
+
+SAM 2 can segment objects it has never encountered before, demonstrating strong zero-shot generalization. This is particularly useful in diverse or evolving visual domains where pre-defined categories may not cover all possible objects.
+
+### Interactive Refinement
+
+Users can iteratively refine the segmentation results by providing additional prompts, allowing for precise control over the output. This interactivity is essential for fine-tuning results in applications like video annotation or medical imaging.
+
+### Advanced Handling of Visual Challenges
+
+SAM 2 includes mechanisms to manage common video segmentation challenges, such as object occlusion and reappearance. It uses a sophisticated memory mechanism to keep track of objects across frames, ensuring continuity even when objects are temporarily obscured or exit and re-enter the scene.
+
+For a deeper understanding of SAM 2's architecture and capabilities, explore the [SAM 2 research paper](https://arxiv.org/abs/2408.00714).
+
+## Performance and Technical Details
+
+SAM 2 sets a new benchmark in the field, outperforming previous models on various metrics:
+
+| Metric | SAM 2 | Previous SOTA |
+| ---------------------------------- | ------------- | ------------- |
+| **Interactive Video Segmentation** | **Best** | - |
+| **Human Interactions Required** | **3x fewer** | Baseline |
+| **Image Segmentation Accuracy** | **Improved** | SAM |
+| **Inference Speed** | **6x faster** | SAM |
+
+## Model Architecture
+
+### Core Components
+
+- **Image and Video Encoder**: Utilizes a transformer-based architecture to extract high-level features from both images and video frames. This component is responsible for understanding the visual content at each timestep.
+- **Prompt Encoder**: Processes user-provided prompts (points, boxes, masks) to guide the segmentation task. This allows SAM 2 to adapt to user input and target specific objects within a scene.
+- **Memory Mechanism**: Includes a memory encoder, memory bank, and memory attention module. These components collectively store and utilize information from past frames, enabling the model to maintain consistent object tracking over time.
+- **Mask Decoder**: Generates the final segmentation masks based on the encoded image features and prompts. In video, it also uses memory context to ensure accurate tracking across frames.
+
+![SAM 2 Architecture Diagram](https://github.com/facebookresearch/segment-anything-2/blob/main/assets/model_diagram.png?raw=true)
+
+### Memory Mechanism and Occlusion Handling
+
+The memory mechanism allows SAM 2 to handle temporal dependencies and occlusions in video data. As objects move and interact, SAM 2 records their features in a memory bank. When an object becomes occluded, the model can rely on this memory to predict its position and appearance when it reappears. The occlusion head specifically handles scenarios where objects are not visible, predicting the likelihood of an object being occluded.
+
+### Multi-Mask Ambiguity Resolution
+
+In situations with ambiguity (e.g., overlapping objects), SAM 2 can generate multiple mask predictions. This capability is crucial for accurately representing complex scenes where a single mask cannot capture every nuance.
+
+## SA-V Dataset
+
+The SA-V dataset, developed for SAM 2's training, is one of the largest and most diverse video segmentation datasets available. It includes:
+
+- **51,000+ Videos**: Captured across 47 countries, providing a wide range of real-world scenarios.
+- **600,000+ Mask Annotations**: Detailed spatio-temporal mask annotations, referred to as "masklets," covering whole objects and parts.
+- **Dataset Scale**: It features 4.5 times more videos and 53 times more annotations than the largest previous video segmentation datasets, offering unprecedented diversity and complexity.
+
+## Benchmarks
+
+### Video Object Segmentation
+
+SAM 2 has demonstrated superior performance across major video segmentation benchmarks:
+
+| Dataset | J&F | J | F |
+| --------------- | ---- | ---- | ---- |
+| **DAVIS 2017** | 82.5 | 79.8 | 85.2 |
+| **YouTube-VOS** | 81.2 | 78.9 | 83.5 |
+
+### Interactive Segmentation
+
+In interactive segmentation tasks, SAM 2 shows significant efficiency and accuracy:
+
+| Dataset | NoC@90 | AUC |
+| --------------------- | ------ | ----- |
+| **DAVIS Interactive** | 1.54 | 0.872 |
+
+## Installation
+
+To install SAM 2, use the following command. All SAM 2 models will automatically download on first use.
+
+```bash
+pip install ultralytics
+```
+
+## How to Use SAM 2: Versatility in Image and Video Segmentation
+
+!!! Note "🚧 SAM 2 Integration In Progress 🚧"
+
+ The SAM 2 features described in this documentation are currently not enabled in the `ultralytics` package. The Ultralytics team is actively working on integrating SAM 2, and these capabilities should be available soon. We appreciate your patience as we work to implement this exciting new model.
+
+The following table details the available SAM 2 models, their pre-trained weights, supported tasks, and compatibility with different operating modes like [Inference](../modes/predict.md), [Validation](../modes/val.md), [Training](../modes/train.md), and [Export](../modes/export.md).
+
+| Model Type | Pre-trained Weights | Tasks Supported | Inference | Validation | Training | Export |
+| ----------- | ------------------------------------------------------------------------------------- | -------------------------------------------- | --------- | ---------- | -------- | ------ |
+| SAM 2 tiny | [sam2_t.pt](https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_t.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ |
+| SAM 2 small | [sam2_s.pt](https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_s.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ |
+| SAM 2 base | [sam2_b.pt](https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_b.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ |
+| SAM 2 large | [sam2_l.pt](https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_l.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ |
+
+### SAM 2 Prediction Examples
+
+SAM 2 can be utilized across a broad spectrum of tasks, including real-time video editing, medical imaging, and autonomous systems. Its ability to segment both static and dynamic visual data makes it a versatile tool for researchers and developers.
+
+#### Segment with Prompts
+
+!!! Example "Segment with Prompts"
+
+ Use prompts to segment specific objects in images or videos.
+
+ === "Python"
+
+ ```python
+ from ultralytics import SAM
+
+ # Load a model
+ model = SAM("sam2_b.pt")
+
+ # Display model information (optional)
+ model.info()
+
+ # Segment with bounding box prompt
+ results = model("path/to/image.jpg", bboxes=[100, 100, 200, 200])
+
+ # Segment with point prompt
+ results = model("path/to/image.jpg", points=[150, 150], labels=[1])
+ ```
+
+#### Segment Everything
+
+!!! Example "Segment Everything"
+
+ Segment the entire image or video content without specific prompts.
+
+ === "Python"
+
+ ```python
+ from ultralytics import SAM
+
+ # Load a model
+ model = SAM("sam2_b.pt")
+
+ # Display model information (optional)
+ model.info()
+
+ # Run inference
+ model("path/to/video.mp4")
+ ```
+
+ === "CLI"
+
+ ```bash
+ # Run inference with a SAM 2 model
+ yolo predict model=sam2_b.pt source=path/to/video.mp4
+ ```
+
+- This example demonstrates how SAM 2 can be used to segment the entire content of an image or video if no prompts (bboxes/points/masks) are provided.
+
+## SAM comparison vs YOLOv8
+
+Here we compare Meta's smallest SAM model, SAM-b, with Ultralytics' smallest segmentation model, [YOLOv8n-seg](../tasks/segment.md):
+
+| Model | Size | Parameters | Speed (CPU) |
+| ---------------------------------------------- | -------------------------- | ---------------------- | -------------------------- |
+| Meta's SAM-b | 358 MB | 94.7 M | 51096 ms/im |
+| [MobileSAM](mobile-sam.md) | 40.7 MB | 10.1 M | 46122 ms/im |
+| [FastSAM-s](fast-sam.md) with YOLOv8 backbone | 23.7 MB | 11.8 M | 115 ms/im |
+| Ultralytics [YOLOv8n-seg](../tasks/segment.md) | **6.7 MB** (53.4x smaller) | **3.4 M** (27.9x less) | **59 ms/im** (866x faster) |
+
+This comparison highlights the order-of-magnitude differences in model size and speed. While SAM offers unique capabilities for automatic segmentation, it is not a direct competitor to YOLOv8 segmentation models, which are smaller, faster, and more efficient.
+
+Tests were run on a 2023 Apple M2 MacBook with 16 GB of RAM. To reproduce this test:
+
+!!! Example
+
+ === "Python"
+
+ ```python
+ from ultralytics import SAM, YOLO, FastSAM
+
+ # Profile SAM-b
+ model = SAM("sam_b.pt")
+ model.info()
+ model("ultralytics/assets")
+
+ # Profile MobileSAM
+ model = SAM("mobile_sam.pt")
+ model.info()
+ model("ultralytics/assets")
+
+ # Profile FastSAM-s
+ model = FastSAM("FastSAM-s.pt")
+ model.info()
+ model("ultralytics/assets")
+
+ # Profile YOLOv8n-seg
+ model = YOLO("yolov8n-seg.pt")
+ model.info()
+ model("ultralytics/assets")
+ ```
+
+## Auto-Annotation: Efficient Dataset Creation
+
+Auto-annotation is a powerful feature of SAM 2, enabling users to generate segmentation datasets quickly and accurately by leveraging pre-trained models. This capability is particularly useful for creating large, high-quality datasets without extensive manual effort.
+
+### How to Auto-Annotate with SAM 2
+
+To auto-annotate your dataset using SAM 2, follow this example:
+
+!!! Example "Auto-Annotation Example"
+
+ ```python
+ from ultralytics.data.annotator import auto_annotate
+
+ auto_annotate(data="path/to/images", det_model="yolov8x.pt", sam_model="sam2_b.pt")
+ ```
+
+| Argument | Type | Description | Default |
+| ------------ | ----------------------- | ------------------------------------------------------------------------------------------------------- | -------------- |
+| `data` | `str` | Path to a folder containing images to be annotated. | |
+| `det_model` | `str`, optional | Pre-trained YOLO detection model. Defaults to 'yolov8x.pt'. | `'yolov8x.pt'` |
+| `sam_model` | `str`, optional | Pre-trained SAM 2 segmentation model. Defaults to 'sam2_b.pt'. | `'sam2_b.pt'` |
+| `device` | `str`, optional | Device to run the models on. Defaults to an empty string (CPU or GPU, if available). | |
+| `output_dir` | `str`, `None`, optional | Directory to save the annotated results. Defaults to a 'labels' folder in the same directory as 'data'. | `None` |
+
+This function facilitates the rapid creation of high-quality segmentation datasets, ideal for researchers and developers aiming to accelerate their projects.
+
+## Limitations
+
+Despite its strengths, SAM 2 has certain limitations:
+
+- **Tracking Stability**: SAM 2 may lose track of objects during extended sequences or significant viewpoint changes.
+- **Object Confusion**: The model can sometimes confuse similar-looking objects, particularly in crowded scenes.
+- **Efficiency with Multiple Objects**: Segmentation efficiency decreases when processing multiple objects simultaneously due to the lack of inter-object communication.
+- **Detail Accuracy**: May miss fine details, especially with fast-moving objects. Additional prompts can partially address this issue, but temporal smoothness is not guaranteed.
+
+## Citations and Acknowledgements
+
+If SAM 2 is a crucial part of your research or development work, please cite it using the following reference:
+
+!!! Quote ""
+
+ === "BibTeX"
+
+ ```bibtex
+ @article{ravi2024sam2,
+ title={SAM 2: Segment Anything in Images and Videos},
+ author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
+ journal={arXiv preprint},
+ year={2024}
+ }
+ ```
+
+We extend our gratitude to Meta AI for their contributions to the AI community with this groundbreaking model and dataset.
+
+## FAQ
+
+### What is SAM 2 and how does it improve upon the original Segment Anything Model (SAM)?
+
+SAM 2, the successor to Meta's [Segment Anything Model (SAM)](sam.md), is a cutting-edge tool designed for comprehensive object segmentation in both images and videos. It excels in handling complex visual data through a unified, promptable model architecture that supports real-time processing and zero-shot generalization. SAM 2 offers several improvements over the original SAM, including:
+
+- **Unified Model Architecture**: Combines image and video segmentation capabilities in a single model.
+- **Real-Time Performance**: Processes approximately 44 frames per second, making it suitable for applications requiring immediate feedback.
+- **Zero-Shot Generalization**: Segments objects it has never encountered before, useful in diverse visual domains.
+- **Interactive Refinement**: Allows users to iteratively refine segmentation results by providing additional prompts.
+- **Advanced Handling of Visual Challenges**: Manages common video segmentation challenges like object occlusion and reappearance.
+
+For more details on SAM 2's architecture and capabilities, explore the [SAM 2 research paper](https://arxiv.org/abs/2408.00714).
+
+### How can I use SAM 2 for real-time video segmentation?
+
+SAM 2 can be utilized for real-time video segmentation by leveraging its promptable interface and real-time inference capabilities. Here's a basic example:
+
+!!! Example "Segment with Prompts"
+
+ Use prompts to segment specific objects in images or videos.
+
+ === "Python"
+
+ ```python
+ from ultralytics import SAM
+
+ # Load a model
+ model = SAM("sam2_b.pt")
+
+ # Display model information (optional)
+ model.info()
+
+ # Segment with bounding box prompt
+ results = model("path/to/image.jpg", bboxes=[100, 100, 200, 200])
+
+ # Segment with point prompt
+ results = model("path/to/image.jpg", points=[150, 150], labels=[1])
+ ```
+
+For more comprehensive usage, refer to the [How to Use SAM 2](#how-to-use-sam-2-versatility-in-image-and-video-segmentation) section.
+
+### What datasets are used to train SAM 2, and how do they enhance its performance?
+
+SAM 2 is trained on the SA-V dataset, one of the largest and most diverse video segmentation datasets available. The SA-V dataset includes:
+
+- **51,000+ Videos**: Captured across 47 countries, providing a wide range of real-world scenarios.
+- **600,000+ Mask Annotations**: Detailed spatio-temporal mask annotations, referred to as "masklets," covering whole objects and parts.
+- **Dataset Scale**: Features 4.5 times more videos and 53 times more annotations than the largest previous video segmentation datasets, offering unprecedented diversity and complexity.
+
+This extensive dataset allows SAM 2 to achieve superior performance across major video segmentation benchmarks and enhances its zero-shot generalization capabilities. For more information, see the [SA-V Dataset](#sa-v-dataset) section.
+
+### How does SAM 2 handle occlusions and object reappearances in video segmentation?
+
+SAM 2 includes a sophisticated memory mechanism to manage temporal dependencies and occlusions in video data. The memory mechanism consists of:
+
+- **Memory Encoder and Memory Bank**: Stores features from past frames.
+- **Memory Attention Module**: Utilizes stored information to maintain consistent object tracking over time.
+- **Occlusion Head**: Specifically handles scenarios where objects are not visible, predicting the likelihood of an object being occluded.
+
+This mechanism ensures continuity even when objects are temporarily obscured or exit and re-enter the scene. For more details, refer to the [Memory Mechanism and Occlusion Handling](#memory-mechanism-and-occlusion-handling) section.
+
+### How does SAM 2 compare to other segmentation models like YOLOv8?
+
+SAM 2 and Ultralytics YOLOv8 serve different purposes and excel in different areas. While SAM 2 is designed for comprehensive object segmentation with advanced features like zero-shot generalization and real-time performance, YOLOv8 is optimized for speed and efficiency in object detection and segmentation tasks. Here's a comparison:
+
+| Model | Size | Parameters | Speed (CPU) |
+| ---------------------------------------------- | -------------------------- | ---------------------- | -------------------------- |
+| Meta's SAM-b | 358 MB | 94.7 M | 51096 ms/im |
+| [MobileSAM](mobile-sam.md) | 40.7 MB | 10.1 M | 46122 ms/im |
+| [FastSAM-s](fast-sam.md) with YOLOv8 backbone | 23.7 MB | 11.8 M | 115 ms/im |
+| Ultralytics [YOLOv8n-seg](../tasks/segment.md) | **6.7 MB** (53.4x smaller) | **3.4 M** (27.9x less) | **59 ms/im** (866x faster) |
+
+For more details, see the [SAM comparison vs YOLOv8](#sam-comparison-vs-yolov8) section.
diff --git a/docs/en/models/sam.md b/docs/en/models/sam.md
index b3301700ba..d7da2be334 100644
--- a/docs/en/models/sam.md
+++ b/docs/en/models/sam.md
@@ -195,13 +195,13 @@ To auto-annotate your dataset with the Ultralytics framework, use the `auto_anno
auto_annotate(data="path/to/images", det_model="yolov8x.pt", sam_model="sam_b.pt")
```
-| Argument | Type | Description | Default |
-| ---------- | ------------------- | ------------------------------------------------------------------------------------------------------- | ------------ |
-| data | str | Path to a folder containing images to be annotated. | |
-| det_model | str, optional | Pre-trained YOLO detection model. Defaults to 'yolov8x.pt'. | 'yolov8x.pt' |
-| sam_model | str, optional | Pre-trained SAM segmentation model. Defaults to 'sam_b.pt'. | 'sam_b.pt' |
-| device | str, optional | Device to run the models on. Defaults to an empty string (CPU or GPU, if available). | |
-| output_dir | str, None, optional | Directory to save the annotated results. Defaults to a 'labels' folder in the same directory as 'data'. | None |
+| Argument | Type | Description | Default |
+| ------------ | --------------------- | ------------------------------------------------------------------------------------------------------- | -------------- |
+| `data` | `str` | Path to a folder containing images to be annotated. | |
+| `det_model` | `str`, optional | Pre-trained YOLO detection model. Defaults to 'yolov8x.pt'. | `'yolov8x.pt'` |
+| `sam_model` | `str`, optional | Pre-trained SAM segmentation model. Defaults to 'sam_b.pt'. | `'sam_b.pt'` |
+| `device` | `str`, optional | Device to run the models on. Defaults to an empty string (CPU or GPU, if available). | |
+| `output_dir` | `str`, None, optional | Directory to save the annotated results. Defaults to a 'labels' folder in the same directory as 'data'. | `None` |
The `auto_annotate` function takes the path to your images, with optional arguments for specifying the pre-trained detection and SAM segmentation models, the device to run the models on, and the output directory for saving the annotated results.
@@ -278,7 +278,3 @@ This function takes the path to your images and optional arguments for pre-train
### What datasets are used to train the Segment Anything Model (SAM)?
SAM is trained on the extensive [SA-1B dataset](https://ai.facebook.com/datasets/segment-anything/) which comprises over 1 billion masks across 11 million images. SA-1B is the largest segmentation dataset to date, providing high-quality and diverse training data, ensuring impressive zero-shot performance in varied segmentation tasks. For more details, visit the [Dataset section](#key-features-of-the-segment-anything-model-sam).
-
----
-
-This FAQ aims to address common questions related to the Segment Anything Model (SAM) from Ultralytics, enhancing user understanding and facilitating effective use of Ultralytics products. For additional information, explore the relevant sections linked throughout.
diff --git a/docs/en/models/yolov10.md b/docs/en/models/yolov10.md
index 4646eb5d45..98a164a42e 100644
--- a/docs/en/models/yolov10.md
+++ b/docs/en/models/yolov10.md
@@ -10,6 +10,17 @@ YOLOv10, built on the [Ultralytics](https://ultralytics.com) [Python package](ht
![YOLOv10 consistent dual assignment for NMS-free training](https://github.com/ultralytics/ultralytics/assets/26833433/f9b1bec0-928e-41ce-a205-e12db3c4929a)
+
+Watch: How to Train YOLOv10 on SKU-110k Dataset using Ultralytics | Retail Dataset
+
## Overview
Real-time object detection aims to accurately predict object categories and positions in images with low latency. The YOLO series has been at the forefront of this research due to its balance between performance and efficiency. However, reliance on NMS and architectural inefficiencies have hindered optimal performance. YOLOv10 addresses these issues by introducing consistent dual assignments for NMS-free training and a holistic efficiency-accuracy driven model design strategy.
diff --git a/docs/en/reference/models/sam/modules/decoders.md b/docs/en/reference/models/sam/modules/decoders.md
index ff3aaa9e7a..a224738b70 100644
--- a/docs/en/reference/models/sam/modules/decoders.md
+++ b/docs/en/reference/models/sam/modules/decoders.md
@@ -13,8 +13,4 @@ keywords: Ultralytics, MaskDecoder, MLP, machine learning, transformer architect
## ::: ultralytics.models.sam.modules.decoders.MaskDecoder
-
diff --git a/docs/en/reference/models/sam/modules/sam.md b/docs/en/reference/models/sam/modules/sam.md
index 5a7c30e42b..a31cece155 100644
--- a/docs/en/reference/models/sam/modules/sam.md
+++ b/docs/en/reference/models/sam/modules/sam.md
@@ -1,6 +1,6 @@
---
-description: Discover the Ultralytics Sam module for object segmentation. Learn about its components, such as image encoders and mask decoders, in this comprehensive guide.
-keywords: Ultralytics, Sam Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning
+description: Discover the Ultralytics SAM module for object segmentation. Learn about its components, such as image encoders and mask decoders, in this comprehensive guide.
+keywords: Ultralytics, SAM Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning
---
# Reference for `ultralytics/models/sam/modules/sam.py`
@@ -11,6 +11,6 @@ keywords: Ultralytics, Sam Module, object segmentation, image encoder, mask deco
-## ::: ultralytics.models.sam.modules.sam.Sam
+## ::: ultralytics.models.sam.modules.sam.SAMModel
diff --git a/docs/en/reference/models/sam2/build.md b/docs/en/reference/models/sam2/build.md
new file mode 100644
index 0000000000..94e9f5befb
--- /dev/null
+++ b/docs/en/reference/models/sam2/build.md
@@ -0,0 +1,36 @@
+---
+description: Discover detailed instructions for building various Segment Anything Model 2 (SAM 2) architectures with Ultralytics.
+keywords: Ultralytics, SAM 2 model, Segment Anything Model 2, SAM, model building, deep learning, AI
+---
+
+# Reference for `ultralytics/models/sam2/build.py`
+
+!!! Note
+
+ This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/build.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/build.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/build.py) 🛠️. Thank you 🙏!
+
+
+
+## ::: ultralytics.models.sam2.build.build_sam2_t
+
+
diff --git a/docs/en/reference/models/sam2/model.md b/docs/en/reference/models/sam2/model.md
new file mode 100644
index 0000000000..fc0d2e2502
--- /dev/null
+++ b/docs/en/reference/models/sam2/model.md
@@ -0,0 +1,16 @@
+---
+description: Explore the SAM 2 (Segment Anything Model 2) interface for real-time image segmentation. Learn about promptable segmentation and zero-shot capabilities.
+keywords: Ultralytics, SAM 2, Segment Anything Model 2, image segmentation, real-time segmentation, zero-shot performance, promptable segmentation, SA-1B dataset
+---
+
+# Reference for `ultralytics/models/sam2/model.py`
+
+!!! Note
+
+ This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/model.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/model.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/model.py) 🛠️. Thank you 🙏!
+
+
+
+## ::: ultralytics.models.sam2.model.SAM2
+
+
diff --git a/docs/en/reference/models/sam2/modules/decoders.md b/docs/en/reference/models/sam2/modules/decoders.md
new file mode 100644
index 0000000000..989169a770
--- /dev/null
+++ b/docs/en/reference/models/sam2/modules/decoders.md
@@ -0,0 +1,16 @@
+---
+description: Explore the MaskDecoder and MLP modules in Ultralytics for efficient mask prediction using transformer architecture. Detailed attributes, functionalities, and implementation.
+keywords: Ultralytics, MaskDecoder, MLP, machine learning, transformer architecture, mask prediction, neural networks, PyTorch modules
+---
+
+# Reference for `ultralytics/models/sam2/modules/decoders.py`
+
+!!! Note
+
+ This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/decoders.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/decoders.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/decoders.py) 🛠️. Thank you 🙏!
+
+
+
+## ::: ultralytics.models.sam2.modules.decoders.MaskDecoder
+
+
diff --git a/docs/en/reference/models/sam2/modules/encoders.md b/docs/en/reference/models/sam2/modules/encoders.md
new file mode 100644
index 0000000000..a3da286e1e
--- /dev/null
+++ b/docs/en/reference/models/sam2/modules/encoders.md
@@ -0,0 +1,28 @@
+---
+description: Discover the Ultralytics SAM 2 module for object segmentation. Learn about its components, such as image encoders and mask decoders, in this comprehensive guide.
+keywords: Ultralytics, SAM 2 Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning
+---
+
+# Reference for `ultralytics/models/sam2/modules/encoders.py`
+
+!!! Note
+
+ This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/encoders.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/encoders.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/encoders.py) 🛠️. Thank you 🙏!
+
+
+
+## ::: ultralytics.models.sam2.modules.encoders.MemoryEncoder
+
+
diff --git a/docs/en/reference/models/sam2/modules/memory_attention.md b/docs/en/reference/models/sam2/modules/memory_attention.md
new file mode 100644
index 0000000000..fef22174ee
--- /dev/null
+++ b/docs/en/reference/models/sam2/modules/memory_attention.md
@@ -0,0 +1,20 @@
+---
+description: Explore detailed documentation of various SAM 2 encoder modules such as MemoryAttentionLayer, MemoryAttention, available in Ultralytics' repository.
+keywords: Ultralytics, SAM 2 encoder, MemoryAttentionLayer, MemoryAttention
+---
+
+# Reference for `ultralytics/models/sam2/modules/memory_attention.py`
+
+!!! Note
+
+ This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/memory_attention.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/memory_attention.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/memory_attention.py) 🛠️. Thank you 🙏!
+
+
+
+## ::: ultralytics.models.sam2.modules.memory_attention.MemoryAttentionLayer
+
+
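+The sketch below constructs the memory attention stack the same way `build_sam2` does, with four stacked layers operating at `d_model=256`:
+
+```python
+from ultralytics.models.sam2.modules.memory_attention import MemoryAttention, MemoryAttentionLayer
+
+# Four deep-copied layers; each performs self-attention, then cross-attention into memory
+memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
+```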
diff --git a/docs/en/reference/models/sam2/modules/sam2.md b/docs/en/reference/models/sam2/modules/sam2.md
new file mode 100644
index 0000000000..fcd47bbdca
--- /dev/null
+++ b/docs/en/reference/models/sam2/modules/sam2.md
@@ -0,0 +1,16 @@
+---
+description: Discover the Ultralytics SAM2Model module for promptable object segmentation, combining image encoding, memory attention, and mask decoding in this comprehensive guide.
+keywords: Ultralytics, SAM 2, SAM2Model, object segmentation, image encoder, memory attention, mask decoder, AI, machine learning
+---
+
+# Reference for `ultralytics/models/sam2/modules/sam2.py`
+
+!!! Note
+
+ This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/sam2.py) 🛠️. Thank you 🙏!
+
+
+
+## ::: ultralytics.models.sam2.modules.sam2.SAM2Model
+
+
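+`SAM2Model` instances are normally created through the builder functions rather than constructed directly. The sketch below assumes the `sam2_b.pt` checkpoint asset can be resolved by `attempt_download_asset`:
+
+```python
+from ultralytics.models.sam2.build import build_sam2
+
+# build_sam2 dispatches on the checkpoint name and returns a SAM2Model in eval mode
+model = build_sam2("sam2_b.pt")
+print(type(model).__name__)  # SAM2Model
+```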
diff --git a/docs/en/reference/models/sam2/modules/sam2_blocks.md b/docs/en/reference/models/sam2/modules/sam2_blocks.md
new file mode 100644
index 0000000000..796669a03e
--- /dev/null
+++ b/docs/en/reference/models/sam2/modules/sam2_blocks.md
@@ -0,0 +1,56 @@
+---
+description: Explore detailed documentation of various SAM 2 modules such as MaskDownSampler, CXBlock, and more, available in Ultralytics' repository.
+keywords: Ultralytics, SAM 2 encoder, DropPath, MaskDownSampler, CXBlock, Fuser, TwoWayTransformer, TwoWayAttentionBlock, RoPEAttention, MultiScaleAttention, MultiScaleBlock, PositionEmbeddingSine, do_pool
+---
+
+# Reference for `ultralytics/models/sam2/modules/sam2_blocks.py`
+
+!!! Note
+
+ This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2_blocks.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2_blocks.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/sam2_blocks.py) 🛠️. Thank you 🙏!
+
+
+
+## ::: ultralytics.models.sam2.modules.sam2_blocks.DropPath
+
+
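+The sketch below shows how some of these blocks are composed elsewhere in the SAM 2 code; the values mirror the `MemoryEncoder` wiring in `modules/encoders.py`:
+
+```python
+from ultralytics.models.sam2.modules.sam2_blocks import CXBlock, Fuser, MaskDownSampler
+
+# Mask downsampling and feature-fusion blocks, configured as in MemoryEncoder
+mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
+fuser = Fuser(CXBlock(dim=256), num_layers=2)
+```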
diff --git a/docs/en/reference/models/sam2/modules/utils.md b/docs/en/reference/models/sam2/modules/utils.md
new file mode 100644
index 0000000000..357cea6237
--- /dev/null
+++ b/docs/en/reference/models/sam2/modules/utils.md
@@ -0,0 +1,44 @@
+---
+description: Explore the detailed API reference for Ultralytics SAM 2 models.
+keywords: Ultralytics, SAM 2, API Reference, models, window partition, data processing, YOLO
+---
+
+# Reference for `ultralytics/models/sam2/modules/utils.py`
+
+!!! Note
+
+ This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/utils.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/utils.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/utils.py) 🛠️. Thank you 🙏!
+
+
+
+## ::: ultralytics.models.sam2.modules.utils.select_closest_cond_frames
+
+
diff --git a/docs/en/reference/models/sam2/predict.md b/docs/en/reference/models/sam2/predict.md
new file mode 100644
index 0000000000..23b30140bc
--- /dev/null
+++ b/docs/en/reference/models/sam2/predict.md
@@ -0,0 +1,16 @@
+---
+description: Explore Ultralytics SAM 2 Predictor for advanced, real-time image segmentation using the Segment Anything Model 2 (SAM 2). Complete implementation details and auxiliary utilities.
+keywords: Ultralytics, SAM 2, Segment Anything Model 2, image segmentation, real-time, prediction, AI, machine learning, Python, torch, inference
+---
+
+# Reference for `ultralytics/models/sam2/predict.py`
+
+!!! Note
+
+ This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/predict.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/predict.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/predict.py) 🛠️. Thank you 🙏!
+
+
+
+## ::: ultralytics.models.sam2.predict.SAM2Predictor
+
+
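+`SAM2Predictor` is normally selected automatically through `SAM2.task_map`, so the usual entry point is the model class. The sketch below assumes a `sam2_b.pt` checkpoint is available, and the image path is a placeholder:
+
+```python
+from ultralytics import SAM2
+
+model = SAM2("sam2_b.pt")
+
+# Point prompt in pixel coordinates with a positive label; SAM2Predictor runs the inference
+results = model.predict("path/to/image.jpg", points=[900, 370], labels=[1])
+```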
diff --git a/mkdocs.yml b/mkdocs.yml
index d849c01e87..fcf5516b55 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -239,6 +239,7 @@ nav:
- YOLOv9: models/yolov9.md
- YOLOv10: models/yolov10.md
- SAM (Segment Anything Model): models/sam.md
+ - SAM 2 (Segment Anything Model 2): models/sam-2.md
- MobileSAM (Mobile Segment Anything Model): models/mobile-sam.md
- FastSAM (Fast Segment Anything Model): models/fast-sam.md
- YOLO-NAS (Neural Architecture Search): models/yolo-nas.md
@@ -508,6 +509,17 @@ nav:
- tiny_encoder: reference/models/sam/modules/tiny_encoder.md
- transformer: reference/models/sam/modules/transformer.md
- predict: reference/models/sam/predict.md
+ - sam2:
+ - build: reference/models/sam2/build.md
+ - model: reference/models/sam2/model.md
+ - modules:
+ - decoders: reference/models/sam2/modules/decoders.md
+ - encoders: reference/models/sam2/modules/encoders.md
+ - memory_attention: reference/models/sam2/modules/memory_attention.md
+ - sam2: reference/models/sam2/modules/sam2.md
+ - sam2_blocks: reference/models/sam2/modules/sam2_blocks.md
+ - utils: reference/models/sam2/modules/utils.md
+ - predict: reference/models/sam2/predict.md
- utils:
- loss: reference/models/utils/loss.md
- ops: reference/models/utils/ops.md
@@ -658,6 +670,7 @@ plugins:
sdk.md: index.md
hub/inference_api.md: hub/inference-api.md
usage/hyperparameter_tuning.md: integrations/ray-tune.md
+ models/sam2.md: models/sam-2.md
reference/base_pred.md: reference/engine/predictor.md
reference/base_trainer.md: reference/engine/trainer.md
reference/exporter.md: reference/engine/exporter.md
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 8c31352897..affb8e35c7 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = "8.2.69"
+__version__ = "8.2.70"
import os
@@ -8,7 +8,7 @@ import os
os.environ["OMP_NUM_THREADS"] = "1" # reduce CPU utilization during training
from ultralytics.data.explorer.explorer import Explorer
-from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld
+from ultralytics.models import NAS, RTDETR, SAM, SAM2, YOLO, FastSAM, YOLOWorld
from ultralytics.utils import ASSETS, SETTINGS
from ultralytics.utils.checks import check_yolo as checks
from ultralytics.utils.downloads import download
@@ -21,6 +21,7 @@ __all__ = (
"YOLOWorld",
"NAS",
"SAM",
+ "SAM2",
"FastSAM",
"RTDETR",
"checks",
diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py
index 4f7ad00eef..4a23ec4267 100644
--- a/ultralytics/cfg/__init__.py
+++ b/ultralytics/cfg/__init__.py
@@ -793,6 +793,10 @@ def entrypoint(debug=""):
from ultralytics import FastSAM
model = FastSAM(model)
+ elif "sam2" in stem:
+ from ultralytics import SAM2
+
+ model = SAM2(model)
elif "sam" in stem:
from ultralytics import SAM
diff --git a/ultralytics/models/__init__.py b/ultralytics/models/__init__.py
index aff620a9a9..bbade04724 100644
--- a/ultralytics/models/__init__.py
+++ b/ultralytics/models/__init__.py
@@ -4,6 +4,7 @@ from .fastsam import FastSAM
from .nas import NAS
from .rtdetr import RTDETR
from .sam import SAM
+from .sam2 import SAM2
from .yolo import YOLO, YOLOWorld
-__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld" # allow simpler import
+__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld", "SAM2" # allow simpler import
diff --git a/ultralytics/models/fastsam/predict.py b/ultralytics/models/fastsam/predict.py
index cd9b302384..2a15c3f39f 100644
--- a/ultralytics/models/fastsam/predict.py
+++ b/ultralytics/models/fastsam/predict.py
@@ -21,6 +21,7 @@ class FastSAMPredictor(SegmentationPredictor):
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+ """Initializes a FastSAMPredictor for fast SAM segmentation tasks in Ultralytics YOLO framework."""
super().__init__(cfg, overrides, _callbacks)
self.prompts = {}
diff --git a/ultralytics/models/sam/build.py b/ultralytics/models/sam/build.py
index cb3a7c6867..253c0159d1 100644
--- a/ultralytics/models/sam/build.py
+++ b/ultralytics/models/sam/build.py
@@ -14,7 +14,7 @@ from ultralytics.utils.downloads import attempt_download_asset
from .modules.decoders import MaskDecoder
from .modules.encoders import ImageEncoderViT, PromptEncoder
-from .modules.sam import Sam
+from .modules.sam import SAMModel
from .modules.tiny_encoder import TinyViT
from .modules.transformer import TwoWayTransformer
@@ -105,7 +105,7 @@ def _build_sam(
out_chans=prompt_embed_dim,
)
)
- sam = Sam(
+ sam = SAMModel(
image_encoder=image_encoder,
prompt_encoder=PromptEncoder(
embed_dim=prompt_embed_dim,
diff --git a/ultralytics/models/sam/model.py b/ultralytics/models/sam/model.py
index feaeb551a7..a819b6eafd 100644
--- a/ultralytics/models/sam/model.py
+++ b/ultralytics/models/sam/model.py
@@ -44,6 +44,7 @@ class SAM(Model):
"""
if model and Path(model).suffix not in {".pt", ".pth"}:
raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
+ self.is_sam2 = "sam2" in Path(model).stem
super().__init__(model=model, task="segment")
def _load(self, weights: str, task=None):
@@ -54,7 +55,12 @@ class SAM(Model):
weights (str): Path to the weights file.
task (str, optional): Task name. Defaults to None.
"""
- self.model = build_sam(weights)
+ if self.is_sam2:
+ from ..sam2.build import build_sam2
+
+ self.model = build_sam2(weights)
+ else:
+ self.model = build_sam(weights)
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""
@@ -112,4 +118,6 @@ class SAM(Model):
Returns:
(dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
"""
- return {"segment": {"predictor": Predictor}}
+ from ..sam2.predict import SAM2Predictor
+
+ return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
diff --git a/ultralytics/models/sam/modules/decoders.py b/ultralytics/models/sam/modules/decoders.py
index 073b1ad40c..eeaab6b453 100644
--- a/ultralytics/models/sam/modules/decoders.py
+++ b/ultralytics/models/sam/modules/decoders.py
@@ -4,9 +4,8 @@ from typing import List, Tuple, Type
import torch
from torch import nn
-from torch.nn import functional as F
-from ultralytics.nn.modules import LayerNorm2d
+from ultralytics.nn.modules import MLP, LayerNorm2d
class MaskDecoder(nn.Module):
@@ -28,7 +27,6 @@ class MaskDecoder(nn.Module):
def __init__(
self,
- *,
transformer_dim: int,
transformer: nn.Module,
num_multimask_outputs: int = 3,
@@ -149,42 +147,3 @@ class MaskDecoder(nn.Module):
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
-
-
-class MLP(nn.Module):
- """
- MLP (Multi-Layer Perceptron) model lightly adapted from
- https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
- """
-
- def __init__(
- self,
- input_dim: int,
- hidden_dim: int,
- output_dim: int,
- num_layers: int,
- sigmoid_output: bool = False,
- ) -> None:
- """
- Initializes the MLP (Multi-Layer Perceptron) model.
-
- Args:
- input_dim (int): The dimensionality of the input features.
- hidden_dim (int): The dimensionality of the hidden layers.
- output_dim (int): The dimensionality of the output layer.
- num_layers (int): The number of hidden layers.
- sigmoid_output (bool, optional): Apply a sigmoid activation to the output layer. Defaults to False.
- """
- super().__init__()
- self.num_layers = num_layers
- h = [hidden_dim] * (num_layers - 1)
- self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
- self.sigmoid_output = sigmoid_output
-
- def forward(self, x):
- """Executes feedforward within the neural network module and applies activation."""
- for i, layer in enumerate(self.layers):
- x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
- if self.sigmoid_output:
- x = torch.sigmoid(x)
- return x
diff --git a/ultralytics/models/sam/modules/encoders.py b/ultralytics/models/sam/modules/encoders.py
index a51c34721a..2bf3945dc5 100644
--- a/ultralytics/models/sam/modules/encoders.py
+++ b/ultralytics/models/sam/modules/encoders.py
@@ -211,6 +211,8 @@ class PromptEncoder(nn.Module):
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight
point_embedding[labels == 1] += self.point_embeddings[1].weight
+ point_embedding[labels == 2] += self.point_embeddings[2].weight
+ point_embedding[labels == 3] += self.point_embeddings[3].weight
return point_embedding
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
@@ -226,8 +228,8 @@ class PromptEncoder(nn.Module):
"""Embeds mask inputs."""
return self.mask_downscaling(masks)
+ @staticmethod
def _get_batch_size(
- self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
diff --git a/ultralytics/models/sam/modules/sam.py b/ultralytics/models/sam/modules/sam.py
index 95d9bbe662..1617527f07 100644
--- a/ultralytics/models/sam/modules/sam.py
+++ b/ultralytics/models/sam/modules/sam.py
@@ -15,15 +15,14 @@ from .decoders import MaskDecoder
from .encoders import ImageEncoderViT, PromptEncoder
-class Sam(nn.Module):
+class SAMModel(nn.Module):
"""
- Sam (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate image
- embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by the mask
- decoder to predict object masks.
+ SAMModel (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate
+ image embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by
+ the mask decoder to predict object masks.
Attributes:
mask_threshold (float): Threshold value for mask prediction.
- image_format (str): Format of the input image, default is 'RGB'.
image_encoder (ImageEncoderViT): The backbone used to encode the image into embeddings.
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
mask_decoder (MaskDecoder): Predicts object masks from the image and prompt embeddings.
@@ -32,7 +31,6 @@ class Sam(nn.Module):
"""
mask_threshold: float = 0.0
- image_format: str = "RGB"
def __init__(
self,
@@ -43,7 +41,7 @@ class Sam(nn.Module):
pixel_std: List[float] = (58.395, 57.12, 57.375),
) -> None:
"""
- Initialize the Sam class to predict object masks from an image and input prompts.
+ Initialize the SAMModel class to predict object masks from an image and input prompts.
Note:
All forward() operations moved to SAMPredictor.
diff --git a/ultralytics/models/sam/modules/transformer.py b/ultralytics/models/sam/modules/transformer.py
index db684f8f13..6375c2adc1 100644
--- a/ultralytics/models/sam/modules/transformer.py
+++ b/ultralytics/models/sam/modules/transformer.py
@@ -86,7 +86,6 @@ class TwoWayTransformer(nn.Module):
(torch.Tensor): the processed image_embedding
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
- bs, c, h, w = image_embedding.shape
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
image_pe = image_pe.flatten(2).permute(0, 2, 1)
@@ -212,6 +211,7 @@ class Attention(nn.Module):
embedding_dim: int,
num_heads: int,
downsample_rate: int = 1,
+ kv_in_dim: int = None,
) -> None:
"""
Initializes the Attention model with the given dimensions and settings.
@@ -226,13 +226,14 @@ class Attention(nn.Module):
"""
super().__init__()
self.embedding_dim = embedding_dim
+ self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads
assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
- self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
- self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
@staticmethod
diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py
index 41076602cd..e03ba62566 100644
--- a/ultralytics/models/sam/predict.py
+++ b/ultralytics/models/sam/predict.py
@@ -168,7 +168,7 @@ class Predictor(BasePredictor):
- np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
- np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
"""
- features = self.model.image_encoder(im) if self.features is None else self.features
+ features = self.get_im_features(im) if self.features is None else self.features
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
@@ -334,7 +334,7 @@ class Predictor(BasePredictor):
"""
device = select_device(self.args.device, verbose=verbose)
if model is None:
- model = build_sam(self.args.model)
+ model = self.get_model()
model.eval()
self.model = model.to(device)
self.device = device
@@ -348,6 +348,10 @@ class Predictor(BasePredictor):
self.model.fp16 = False
self.done_warmup = True
+ def get_model(self):
+        """Builds and returns a Segment Anything Model (SAM) for inference."""
+ return build_sam(self.args.model)
+
def postprocess(self, preds, img, orig_imgs):
"""
Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
@@ -412,16 +416,18 @@ class Predictor(BasePredictor):
AssertionError: If more than one image is set.
"""
if self.model is None:
- model = build_sam(self.args.model)
- self.setup_model(model)
+ self.setup_model(model=None)
self.setup_source(image)
assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
for batch in self.dataset:
im = self.preprocess(batch[1])
- self.features = self.model.image_encoder(im)
- self.im = im
+ self.features = self.get_im_features(im)
break
+ def get_im_features(self, im):
+ """Get image features from the SAM image encoder."""
+ return self.model.image_encoder(im)
+
def set_prompts(self, prompts):
"""Set prompts in advance."""
self.prompts = prompts
diff --git a/ultralytics/models/sam2/__init__.py b/ultralytics/models/sam2/__init__.py
new file mode 100644
index 0000000000..755a160af5
--- /dev/null
+++ b/ultralytics/models/sam2/__init__.py
@@ -0,0 +1,6 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+from .model import SAM2
+from .predict import SAM2Predictor
+
+__all__ = "SAM2", "SAM2Predictor" # tuple or list
diff --git a/ultralytics/models/sam2/build.py b/ultralytics/models/sam2/build.py
new file mode 100644
index 0000000000..2791029a52
--- /dev/null
+++ b/ultralytics/models/sam2/build.py
@@ -0,0 +1,156 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import torch
+
+from ultralytics.utils.downloads import attempt_download_asset
+
+from .modules.encoders import FpnNeck, Hiera, ImageEncoder, MemoryEncoder
+from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
+from .modules.sam2 import SAM2Model
+
+
+def build_sam2_t(checkpoint=None):
+ """Build and return a Segment Anything Model (SAM2) tiny-size model with specified architecture parameters."""
+ return _build_sam2(
+ encoder_embed_dim=96,
+ encoder_stages=[1, 2, 7, 2],
+ encoder_num_heads=1,
+ encoder_global_att_blocks=[5, 7, 9],
+ encoder_window_spec=[8, 4, 14, 7],
+ encoder_backbone_channel_list=[768, 384, 192, 96],
+ checkpoint=checkpoint,
+ )
+
+
+def build_sam2_s(checkpoint=None):
+ """Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters."""
+ return _build_sam2(
+ encoder_embed_dim=96,
+ encoder_stages=[1, 2, 11, 2],
+ encoder_num_heads=1,
+ encoder_global_att_blocks=[7, 10, 13],
+ encoder_window_spec=[8, 4, 14, 7],
+ encoder_backbone_channel_list=[768, 384, 192, 96],
+ checkpoint=checkpoint,
+ )
+
+
+def build_sam2_b(checkpoint=None):
+ """Builds and returns a Segment Anything Model (SAM2) base-size model with specified architecture parameters."""
+ return _build_sam2(
+ encoder_embed_dim=112,
+ encoder_stages=[2, 3, 16, 3],
+ encoder_num_heads=2,
+ encoder_global_att_blocks=[12, 16, 20],
+ encoder_window_spec=[8, 4, 14, 7],
+ encoder_window_spatial_size=[14, 14],
+ encoder_backbone_channel_list=[896, 448, 224, 112],
+ checkpoint=checkpoint,
+ )
+
+
+def build_sam2_l(checkpoint=None):
+ """Build and return a Segment Anything Model (SAM2) large-size model with specified architecture parameters."""
+ return _build_sam2(
+ encoder_embed_dim=144,
+ encoder_stages=[2, 6, 36, 4],
+ encoder_num_heads=2,
+ encoder_global_att_blocks=[23, 33, 43],
+ encoder_window_spec=[8, 4, 16, 8],
+ encoder_backbone_channel_list=[1152, 576, 288, 144],
+ checkpoint=checkpoint,
+ )
+
+
+def _build_sam2(
+ encoder_embed_dim=1280,
+ encoder_stages=[2, 6, 36, 4],
+ encoder_num_heads=2,
+ encoder_global_att_blocks=[7, 15, 23, 31],
+ encoder_backbone_channel_list=[1152, 576, 288, 144],
+ encoder_window_spatial_size=[7, 7],
+ encoder_window_spec=[8, 4, 16, 8],
+ checkpoint=None,
+):
+ """Builds a SAM2 model with specified architecture parameters and optional checkpoint loading."""
+ image_encoder = ImageEncoder(
+ trunk=Hiera(
+ embed_dim=encoder_embed_dim,
+ num_heads=encoder_num_heads,
+ stages=encoder_stages,
+ global_att_blocks=encoder_global_att_blocks,
+ window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
+ window_spec=encoder_window_spec,
+ ),
+ neck=FpnNeck(
+ d_model=256,
+ backbone_channel_list=encoder_backbone_channel_list,
+ fpn_top_down_levels=[2, 3],
+ fpn_interp_model="nearest",
+ ),
+ scalp=1,
+ )
+ memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
+ memory_encoder = MemoryEncoder(out_dim=64)
+
+ sam2 = SAM2Model(
+ image_encoder=image_encoder,
+ memory_attention=memory_attention,
+ memory_encoder=memory_encoder,
+ num_maskmem=7,
+ image_size=1024,
+ sigmoid_scale_for_mem_enc=20.0,
+ sigmoid_bias_for_mem_enc=-10.0,
+ use_mask_input_as_output_without_sam=True,
+ directly_add_no_mem_embed=True,
+ use_high_res_features_in_sam=True,
+ multimask_output_in_sam=True,
+ iou_prediction_use_sigmoid=True,
+ use_obj_ptrs_in_encoder=True,
+ add_tpos_enc_to_obj_ptrs=True,
+ only_obj_ptrs_in_the_past_for_eval=True,
+ pred_obj_scores=True,
+ pred_obj_scores_mlp=True,
+ fixed_no_obj_ptr=True,
+ multimask_output_for_tracking=True,
+ use_multimask_token_for_obj_ptr=True,
+ multimask_min_pt_num=0,
+ multimask_max_pt_num=1,
+ use_mlp_for_obj_ptr_proj=True,
+ compile_image_encoder=False,
+ sam_mask_decoder_extra_args=dict(
+ dynamic_multimask_via_stability=True,
+ dynamic_multimask_stability_delta=0.05,
+ dynamic_multimask_stability_thresh=0.98,
+ ),
+ )
+
+ if checkpoint is not None:
+ checkpoint = attempt_download_asset(checkpoint)
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f)["model"]
+ sam2.load_state_dict(state_dict)
+ sam2.eval()
+ return sam2
+
+
+sam_model_map = {
+ "sam2_t.pt": build_sam2_t,
+ "sam2_s.pt": build_sam2_s,
+ "sam2_b.pt": build_sam2_b,
+ "sam2_l.pt": build_sam2_l,
+}
+
+
+def build_sam2(ckpt="sam2_b.pt"):
+ """Constructs a Segment Anything Model (SAM2) based on the specified checkpoint, with various size options."""
+ model_builder = None
+ ckpt = str(ckpt) # to allow Path ckpt types
+ for k in sam_model_map.keys():
+ if ckpt.endswith(k):
+ model_builder = sam_model_map.get(k)
+
+ if not model_builder:
+        raise FileNotFoundError(f"{ckpt} is not a supported SAM 2 model. Available models are: \n {sam_model_map.keys()}")
+
+ return model_builder(ckpt)
diff --git a/ultralytics/models/sam2/model.py b/ultralytics/models/sam2/model.py
new file mode 100644
index 0000000000..4b3265e197
--- /dev/null
+++ b/ultralytics/models/sam2/model.py
@@ -0,0 +1,97 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+"""
+SAM2 model interface.
+
+This module provides an interface to the Segment Anything Model (SAM2) from Ultralytics, designed for real-time image
+segmentation tasks. The SAM2 model allows for promptable segmentation with unparalleled versatility in image analysis,
+and has been trained on the SA-V dataset. It features zero-shot performance capabilities, enabling it to adapt to new
+image distributions and tasks without prior knowledge.
+
+Key Features:
+ - Promptable segmentation
+ - Real-time performance
+ - Zero-shot transfer capabilities
+    - Trained on the SA-V dataset
+"""
+
+from ultralytics.models.sam import SAM
+
+from .build import build_sam2
+from .predict import SAM2Predictor
+
+
+class SAM2(SAM):
+ """
+ SAM2 class for real-time image segmentation using the Segment Anything Model (SAM2).
+
+ This class extends the SAM base class, providing an interface to the SAM2 model for promptable segmentation
+ tasks. It supports loading pre-trained weights and offers zero-shot performance capabilities.
+
+ Attributes:
+ model (torch.nn.Module): The loaded SAM2 model.
+ task_map (Dict[str, Type[SAM2Predictor]]): Mapping of 'segment' task to SAM2Predictor.
+
+ Methods:
+ __init__: Initializes the SAM2 model with pre-trained weights.
+ _load: Loads specified weights into the SAM2 model.
+
+ Examples:
+ >>> sam2 = SAM2("sam2_b.pt")
+ >>> sam2._load('path/to/sam2_weights.pt')
+ >>> task_map = sam2.task_map
+ >>> print(task_map)
+ {'segment': SAM2Predictor}
+
+ Notes:
+ - Supports .pt and .pth file extensions for model weights.
+ - Offers zero-shot transfer capabilities for new image distributions and tasks.
+ """
+
+ def __init__(self, model="sam2_b.pt") -> None:
+ """
+ Initializes the SAM2 model with a pre-trained model file.
+
+ Args:
+ model (str): Path to the pre-trained SAM2 model file. File should have a .pt or .pth extension.
+
+ Raises:
+ NotImplementedError: If the model file extension is not .pt or .pth.
+
+ Examples:
+ >>> sam2 = SAM2("sam2_b.pt")
+ """
+ super().__init__(model=model)
+
+ def _load(self, weights: str, task=None):
+ """
+ Loads the specified weights into the SAM2 model.
+
+ This method is responsible for loading pre-trained weights into the SAM2 model. It supports loading
+ weights from files with .pt or .pth extensions.
+
+ Args:
+ weights (str): Path to the weights file. Should be a file with .pt or .pth extension.
+ task (str | None): Task name. If provided, it may be used to configure model-specific settings.
+
+ Examples:
+ >>> sam2_model = SAM2()
+ >>> sam2_model._load('path/to/sam2_weights.pt')
+ """
+ self.model = build_sam2(weights)
+
+ @property
+ def task_map(self):
+ """
+ Provides a mapping from the 'segment' task to its corresponding 'Predictor'.
+
+ Returns:
+ (Dict[str, Type[SAM2Predictor]]): A dictionary mapping the 'segment' task to its corresponding
+ SAM2Predictor class.
+
+ Examples:
+ >>> sam2 = SAM2()
+ >>> task_map = sam2.task_map
+ >>> print(task_map)
+ {'segment': SAM2Predictor}
+ """
+ return {"segment": {"predictor": SAM2Predictor}}
diff --git a/ultralytics/models/sam2/modules/__init__.py b/ultralytics/models/sam2/modules/__init__.py
new file mode 100644
index 0000000000..9e68dc1224
--- /dev/null
+++ b/ultralytics/models/sam2/modules/__init__.py
@@ -0,0 +1 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
diff --git a/ultralytics/models/sam2/modules/decoders.py b/ultralytics/models/sam2/modules/decoders.py
new file mode 100644
index 0000000000..ac6c60a608
--- /dev/null
+++ b/ultralytics/models/sam2/modules/decoders.py
@@ -0,0 +1,305 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+
+from ultralytics.nn.modules import MLP, LayerNorm2d
+
+
+class MaskDecoder(nn.Module):
+ """Transformer-based decoder predicting instance segmentation masks from image and prompt embeddings."""
+
+ def __init__(
+ self,
+ transformer_dim: int,
+ transformer: nn.Module,
+ num_multimask_outputs: int = 3,
+ activation: Type[nn.Module] = nn.GELU,
+ iou_head_depth: int = 3,
+ iou_head_hidden_dim: int = 256,
+ use_high_res_features: bool = False,
+ iou_prediction_use_sigmoid=False,
+ dynamic_multimask_via_stability=False,
+ dynamic_multimask_stability_delta=0.05,
+ dynamic_multimask_stability_thresh=0.98,
+ pred_obj_scores: bool = False,
+ pred_obj_scores_mlp: bool = False,
+ use_multimask_token_for_obj_ptr: bool = False,
+ ) -> None:
+ """
+ Initializes the MaskDecoder module for predicting instance segmentation masks.
+
+ Args:
+ transformer_dim (int): Channel dimension of the transformer.
+ transformer (nn.Module): Transformer used to predict masks.
+ num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
+ activation (Type[nn.Module]): Type of activation to use when upscaling masks.
+ iou_head_depth (int): Depth of the MLP used to predict mask quality.
+ iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
+ use_high_res_features (bool): Whether to use high-resolution features.
+ iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction.
+ dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
+ dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
+ dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
+ pred_obj_scores (bool): Whether to predict object scores.
+ pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction.
+ use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
+
+ Attributes:
+ transformer_dim (int): Channel dimension of the transformer.
+ transformer (nn.Module): Transformer used to predict masks.
+ num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
+ iou_token (nn.Embedding): Embedding for IOU token.
+ num_mask_tokens (int): Total number of mask tokens.
+ mask_tokens (nn.Embedding): Embedding for mask tokens.
+ pred_obj_scores (bool): Whether to predict object scores.
+ obj_score_token (nn.Embedding): Embedding for object score token.
+ use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
+ output_upscaling (nn.Sequential): Upscaling layers for output.
+ use_high_res_features (bool): Whether to use high-resolution features.
+ conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0).
+ conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1).
+ output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks.
+ iou_prediction_head (MLP): MLP for IOU prediction.
+ pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction.
+ dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
+ dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
+ """
+ super().__init__()
+ self.transformer_dim = transformer_dim
+ self.transformer = transformer
+
+ self.num_multimask_outputs = num_multimask_outputs
+
+ self.iou_token = nn.Embedding(1, transformer_dim)
+ self.num_mask_tokens = num_multimask_outputs + 1
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+ self.pred_obj_scores = pred_obj_scores
+ if self.pred_obj_scores:
+ self.obj_score_token = nn.Embedding(1, transformer_dim)
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
+
+ self.output_upscaling = nn.Sequential(
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+ LayerNorm2d(transformer_dim // 4),
+ activation(),
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+ activation(),
+ )
+ self.use_high_res_features = use_high_res_features
+ if use_high_res_features:
+ self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1)
+ self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1)
+
+ self.output_hypernetworks_mlps = nn.ModuleList(
+ [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
+ )
+
+ self.iou_prediction_head = MLP(
+ transformer_dim,
+ iou_head_hidden_dim,
+ self.num_mask_tokens,
+ iou_head_depth,
+ sigmoid=iou_prediction_use_sigmoid,
+ )
+ if self.pred_obj_scores:
+ self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
+ if pred_obj_scores_mlp:
+ self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
+
+ # When outputting a single mask, optionally we can dynamically fall back to the best
+ # multimask output token if the single mask output token gives low stability scores.
+ self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
+ self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
+ self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
+
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ multimask_output: bool,
+ repeat_image: bool,
+ high_res_features: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Predicts masks given image and prompt embeddings.
+
+ Args:
+ image_embeddings (torch.Tensor): Embeddings from the image encoder.
+ image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings.
+ sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes.
+ dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs.
+ multimask_output (bool): Whether to return multiple masks or a single mask.
+ repeat_image (bool): Flag to repeat the image embeddings.
+ high_res_features (List[torch.Tensor] | None): Optional high-resolution features.
+
+ Returns:
+            (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
+                - masks (torch.Tensor): Batched predicted masks.
+                - iou_pred (torch.Tensor): Batched predictions of mask quality.
+                - sam_tokens_out (torch.Tensor): Batched SAM token for mask output.
+                - object_score_logits (torch.Tensor): Batched object score logits.
+
+ Examples:
+ >>> image_embeddings = torch.rand(1, 256, 64, 64)
+ >>> image_pe = torch.rand(1, 256, 64, 64)
+ >>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
+ >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
+ >>> decoder = MaskDecoder(256, transformer)
+            >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(image_embeddings, image_pe,
+            ...     sparse_prompt_embeddings, dense_prompt_embeddings, True, False)
+ """
+ masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
+ image_embeddings=image_embeddings,
+ image_pe=image_pe,
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
+ dense_prompt_embeddings=dense_prompt_embeddings,
+ repeat_image=repeat_image,
+ high_res_features=high_res_features,
+ )
+
+ # Select the correct mask or masks for output
+ if multimask_output:
+ masks = masks[:, 1:, :, :]
+ iou_pred = iou_pred[:, 1:]
+ elif self.dynamic_multimask_via_stability and not self.training:
+ masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
+ else:
+ masks = masks[:, 0:1, :, :]
+ iou_pred = iou_pred[:, 0:1]
+
+ if multimask_output and self.use_multimask_token_for_obj_ptr:
+ sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
+ else:
+ # Take the mask output token. Here we *always* use the token for single mask output.
+ # At test time, even if we track after 1-click (and using multimask_output=True),
+ # we still take the single mask token here. The rationale is that we always track
+ # after multiple clicks during training, so the past tokens seen during training
+ # are always the single mask token (and we'll let it be the object-memory token).
+ sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
+
+ # Prepare output
+ return masks, iou_pred, sam_tokens_out, object_score_logits
+
+ def predict_masks(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ repeat_image: bool,
+ high_res_features: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Predicts instance segmentation masks from image and prompt embeddings using a transformer architecture."""
+ # Concatenate output tokens
+ s = 0
+ if self.pred_obj_scores:
+ output_tokens = torch.cat(
+ [
+ self.obj_score_token.weight,
+ self.iou_token.weight,
+ self.mask_tokens.weight,
+ ],
+ dim=0,
+ )
+ s = 1
+ else:
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+ # Expand per-image data in batch direction to be per-mask
+ if repeat_image:
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+ else:
+ assert image_embeddings.shape[0] == tokens.shape[0]
+ src = image_embeddings
+ src = src + dense_prompt_embeddings
+ assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+ b, c, h, w = src.shape
+
+ # Run the transformer
+ hs, src = self.transformer(src, pos_src, tokens)
+ iou_token_out = hs[:, s, :]
+ mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
+
+ # Upscale mask embeddings and predict masks using the mask tokens
+ src = src.transpose(1, 2).view(b, c, h, w)
+ if not self.use_high_res_features:
+ upscaled_embedding = self.output_upscaling(src)
+ else:
+ dc1, ln1, act1, dc2, act2 = self.output_upscaling
+ feat_s0, feat_s1 = high_res_features
+ upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
+ upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
+
+ hyper_in_list: List[torch.Tensor] = []
+ for i in range(self.num_mask_tokens):
+ hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
+ hyper_in = torch.stack(hyper_in_list, dim=1)
+ b, c, h, w = upscaled_embedding.shape
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+ if self.pred_obj_scores:
+ assert s == 1
+ object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
+ else:
+ # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
+ object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
+
+ return masks, iou_pred, mask_tokens_out, object_score_logits
+
+ def _get_stability_scores(self, mask_logits):
+ """Computes mask stability scores based on IoU between upper and lower thresholds."""
+ mask_logits = mask_logits.flatten(-2)
+ stability_delta = self.dynamic_multimask_stability_delta
+ area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
+ area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+ stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
+ return stability_scores
+
+ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
+ """
+ Dynamically selects the most stable mask output based on stability scores and IoU predictions.
+
+ When outputting a single mask, if the stability score from the current single-mask output (based on output token
+ 0) falls below a threshold, we instead select from multi-mask outputs (based on output token 1~3) the mask with
+ the highest predicted IoU score.
+
+ This is intended to ensure a valid mask for both clicking and tracking.
+ """
+ # The best mask from multimask output tokens (1~3)
+ multimask_logits = all_mask_logits[:, 1:, :, :]
+ multimask_iou_scores = all_iou_scores[:, 1:]
+ best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
+ batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
+ best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
+ best_multimask_logits = best_multimask_logits.unsqueeze(1)
+ best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
+ best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
+
+ # The mask from singlemask output token 0 and its stability score
+ singlemask_logits = all_mask_logits[:, 0:1, :, :]
+ singlemask_iou_scores = all_iou_scores[:, 0:1]
+ stability_scores = self._get_stability_scores(singlemask_logits)
+ is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
+
+ # Dynamically fall back to best multimask output upon low stability scores.
+ mask_logits_out = torch.where(
+ is_stable[..., None, None].expand_as(singlemask_logits),
+ singlemask_logits,
+ best_multimask_logits,
+ )
+ iou_scores_out = torch.where(
+ is_stable.expand_as(singlemask_iou_scores),
+ singlemask_iou_scores,
+ best_multimask_iou_scores,
+ )
+ return mask_logits_out, iou_scores_out
diff --git a/ultralytics/models/sam2/modules/encoders.py b/ultralytics/models/sam2/modules/encoders.py
new file mode 100644
index 0000000000..b4cd0f8f25
--- /dev/null
+++ b/ultralytics/models/sam2/modules/encoders.py
@@ -0,0 +1,332 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ultralytics.models.sam.modules.encoders import PatchEmbed
+
+from .sam2_blocks import CXBlock, Fuser, MaskDownSampler, MultiScaleBlock, PositionEmbeddingSine
+
+
+class MemoryEncoder(nn.Module):
+ """Encodes pixel features and masks into a memory representation for efficient image segmentation."""
+
+ def __init__(
+ self,
+ out_dim,
+ in_dim=256, # in_dim of pix_feats
+ ):
+ """Initializes the MemoryEncoder module for encoding pixel features and masks in SAM-like models."""
+ super().__init__()
+
+ self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
+
+ self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
+ self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
+ self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)
+ self.out_proj = nn.Identity()
+ if out_dim != in_dim:
+ self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
+
+ def forward(
+ self,
+ pix_feat: torch.Tensor,
+ masks: torch.Tensor,
+ skip_mask_sigmoid: bool = False,
+    ) -> dict:
+ """Processes pixel features and masks, fusing them to generate encoded memory representations."""
+ if not skip_mask_sigmoid:
+ masks = F.sigmoid(masks)
+ masks = self.mask_downsampler(masks)
+
+        # Fuse pixel features and downsampled masks; if the visual features are on a different device, move them to the masks' device
+ pix_feat = pix_feat.to(masks.device)
+
+ x = self.pix_feat_proj(pix_feat)
+ x = x + masks
+ x = self.fuser(x)
+ x = self.out_proj(x)
+
+ pos = self.position_encoding(x).to(x.dtype)
+
+ return {"vision_features": x, "vision_pos_enc": [pos]}
+
+
+class ImageEncoder(nn.Module):
+ """Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings."""
+
+ def __init__(
+ self,
+ trunk: nn.Module,
+ neck: nn.Module,
+ scalp: int = 0,
+ ):
+ """Initializes an image encoder with a trunk, neck, and optional scalp for feature extraction."""
+ super().__init__()
+ self.trunk = trunk
+ self.neck = neck
+ self.scalp = scalp
+ assert (
+ self.trunk.channel_list == self.neck.backbone_channel_list
+ ), f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
+
+ def forward(self, sample: torch.Tensor):
+ """Processes image input through trunk and neck, returning features, positional encodings, and FPN outputs."""
+ features, pos = self.neck(self.trunk(sample))
+ if self.scalp > 0:
+ # Discard the lowest resolution features
+ features, pos = features[: -self.scalp], pos[: -self.scalp]
+
+ src = features[-1]
+ output = {
+ "vision_features": src,
+ "vision_pos_enc": pos,
+ "backbone_fpn": features,
+ }
+ return output
+
+
+class FpnNeck(nn.Module):
+ """Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models."""
+
+ def __init__(
+ self,
+ d_model: int,
+ backbone_channel_list: List[int],
+ kernel_size: int = 1,
+ stride: int = 1,
+ padding: int = 0,
+ fpn_interp_model: str = "bilinear",
+ fuse_type: str = "sum",
+ fpn_top_down_levels: Optional[List[int]] = None,
+ ):
+ """
+ Initializes a modified Feature Pyramid Network (FPN) neck.
+
+ This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
+ similar to ViT positional embedding interpolation.
+
+ Args:
+ d_model (int): Dimension of the model.
+ backbone_channel_list (List[int]): List of channel dimensions from the backbone.
+ kernel_size (int): Kernel size for the convolutional layers.
+ stride (int): Stride for the convolutional layers.
+ padding (int): Padding for the convolutional layers.
+ fpn_interp_model (str): Interpolation mode for FPN feature resizing.
+ fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
+ fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs.
+
+ Attributes:
+ position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding.
+ convs (nn.ModuleList): List of convolutional layers for each backbone level.
+ backbone_channel_list (List[int]): List of channel dimensions from the backbone.
+ fpn_interp_model (str): Interpolation mode for FPN feature resizing.
+ fuse_type (str): Type of feature fusion.
+ fpn_top_down_levels (List[int]): Levels with top-down feature propagation.
+
+ Examples:
+ >>> backbone_channels = [64, 128, 256, 512]
+ >>> fpn_neck = FpnNeck(256, backbone_channels)
+ >>> print(fpn_neck)
+ """
+ super().__init__()
+ self.position_encoding = PositionEmbeddingSine(num_pos_feats=256)
+ self.convs = nn.ModuleList()
+ self.backbone_channel_list = backbone_channel_list
+ for dim in backbone_channel_list:
+ current = nn.Sequential()
+ current.add_module(
+ "conv",
+ nn.Conv2d(
+ in_channels=dim,
+ out_channels=d_model,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ ),
+ )
+
+ self.convs.append(current)
+ self.fpn_interp_model = fpn_interp_model
+ assert fuse_type in ["sum", "avg"]
+ self.fuse_type = fuse_type
+
+ # levels to have top-down features in its outputs
+ # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
+ # have top-down propagation, while outputs of level 0 and level 1 have only
+ # lateral features from the same backbone level.
+ if fpn_top_down_levels is None:
+ # default is to have top-down features on all levels
+ fpn_top_down_levels = range(len(self.convs))
+ self.fpn_top_down_levels = list(fpn_top_down_levels)
+
+ def forward(self, xs: List[torch.Tensor]):
+ """
+ Performs forward pass through the Feature Pyramid Network (FPN) neck.
+
+ Args:
+ xs (List[torch.Tensor]): List of input tensors from the backbone, with shape (B, C, H, W) for each tensor.
+
+ Returns:
+ (Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing two lists:
+ - out: List of output feature maps after FPN processing, with shape (B, d_model, H, W) for each tensor.
+ - pos: List of positional encodings corresponding to each output feature map.
+
+ Examples:
+ >>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])
+ >>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]]
+ >>> outputs, positions = fpn_neck(inputs)
+ """
+ out = [None] * len(self.convs)
+ pos = [None] * len(self.convs)
+ assert len(xs) == len(self.convs)
+ # fpn forward pass
+ # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
+ prev_features = None
+ # forward in top-down order (from low to high resolution)
+ n = len(self.convs) - 1
+ for i in range(n, -1, -1):
+ x = xs[i]
+ lateral_features = self.convs[n - i](x)
+ if i in self.fpn_top_down_levels and prev_features is not None:
+ top_down_features = F.interpolate(
+ prev_features.to(dtype=torch.float32),
+ scale_factor=2.0,
+ mode=self.fpn_interp_model,
+ align_corners=(None if self.fpn_interp_model == "nearest" else False),
+ antialias=False,
+ )
+ prev_features = lateral_features + top_down_features
+ if self.fuse_type == "avg":
+ prev_features /= 2
+ else:
+ prev_features = lateral_features
+ x_out = prev_features
+ out[i] = x_out
+ pos[i] = self.position_encoding(x_out).to(x_out.dtype)
+
+ return out, pos
+
+
+class Hiera(nn.Module):
+ """Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks."""
+
+ def __init__(
+ self,
+ embed_dim: int = 96, # initial embed dim
+ num_heads: int = 1, # initial number of heads
+ drop_path_rate: float = 0.0, # stochastic depth
+ q_pool: int = 3, # number of q_pool stages
+        q_stride: Tuple[int, int] = (2, 2),  # downsample stride between stages
+        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
+        dim_mul: float = 2.0,  # dim_mul factor at stage shift
+        head_mul: float = 2.0,  # head_mul factor at stage shift
+        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
+        window_spec: Tuple[int, ...] = (8, 4, 14, 7),  # window size per stage when not using global attention
+        global_att_blocks: Tuple[int, ...] = (12, 16, 20),  # blocks using global attention
+ return_interm_layers=True, # return feats from every stage
+ ):
+ """Initializes a Hiera model with configurable architecture for hierarchical vision transformers."""
+ super().__init__()
+
+ assert len(stages) == len(window_spec)
+ self.window_spec = window_spec
+
+ depth = sum(stages)
+ self.q_stride = q_stride
+ self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
+ assert 0 <= q_pool <= len(self.stage_ends[:-1])
+ self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
+ self.return_interm_layers = return_interm_layers
+
+ self.patch_embed = PatchEmbed(
+ embed_dim=embed_dim,
+ kernel_size=(7, 7),
+ stride=(4, 4),
+ padding=(3, 3),
+ )
+ # Which blocks have global att?
+ self.global_att_blocks = global_att_blocks
+
+ # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
+ self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
+ self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))
+ self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ cur_stage = 1
+ self.blocks = nn.ModuleList()
+
+ for i in range(depth):
+ dim_out = embed_dim
+            # The window size lags by one block: the first block of the next stage uses the
+            # initial window size of the previous stage, and the final block uses the window
+            # size of the current stage.
+ window_size = self.window_spec[cur_stage - 1]
+
+ if self.global_att_blocks is not None:
+ window_size = 0 if i in self.global_att_blocks else window_size
+
+ if i - 1 in self.stage_ends:
+ dim_out = int(embed_dim * dim_mul)
+ num_heads = int(num_heads * head_mul)
+ cur_stage += 1
+
+ block = MultiScaleBlock(
+ dim=embed_dim,
+ dim_out=dim_out,
+ num_heads=num_heads,
+ drop_path=dpr[i],
+ q_stride=self.q_stride if i in self.q_pool_blocks else None,
+ window_size=window_size,
+ )
+
+ embed_dim = dim_out
+ self.blocks.append(block)
+
+ self.channel_list = (
+ [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
+ if return_interm_layers
+ else [self.blocks[-1].dim_out]
+ )
+
+ def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
+ """Generate positional embeddings by interpolating and combining window and background embeddings."""
+ h, w = hw
+ window_embed = self.pos_embed_window
+ pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
+ pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
+ pos_embed = pos_embed.permute(0, 2, 3, 1)
+ return pos_embed
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ """Performs hierarchical vision transformer forward pass, returning multiscale feature maps."""
+ x = self.patch_embed(x)
+ # x: (B, H, W, C)
+
+ # Add pos embed
+ x = x + self._get_pos_embed(x.shape[1:3])
+
+ outputs = []
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):
+ feats = x.permute(0, 3, 1, 2)
+ outputs.append(feats)
+
+ return outputs
diff --git a/ultralytics/models/sam2/modules/memory_attention.py b/ultralytics/models/sam2/modules/memory_attention.py
new file mode 100644
index 0000000000..8b5673c314
--- /dev/null
+++ b/ultralytics/models/sam2/modules/memory_attention.py
@@ -0,0 +1,170 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import copy
+from typing import Optional
+
+import torch
+from torch import Tensor, nn
+
+from .sam2_blocks import RoPEAttention
+
+
+class MemoryAttentionLayer(nn.Module):
+ """Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks."""
+
+ def __init__(
+ self,
+ d_model: int = 256,
+ dim_feedforward: int = 2048,
+ dropout: float = 0.1,
+ pos_enc_at_attn: bool = False,
+ pos_enc_at_cross_attn_keys: bool = True,
+ pos_enc_at_cross_attn_queries: bool = False,
+ ):
+ """Initializes a MemoryAttentionLayer with self-attention, cross-attention, and feedforward components."""
+ super().__init__()
+ self.d_model = d_model
+ self.dim_feedforward = dim_feedforward
+ self.dropout_value = dropout
+ self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
+ self.cross_attn_image = RoPEAttention(
+ rope_k_repeat=True,
+ embedding_dim=256,
+ num_heads=1,
+ downsample_rate=1,
+ kv_in_dim=64,
+ )
+
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = nn.ReLU()
+
+ # Where to add pos enc
+ self.pos_enc_at_attn = pos_enc_at_attn
+ self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
+ self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
+
+ def _forward_sa(self, tgt, query_pos):
+ """Performs self-attention on input tensor using positional encoding and RoPE attention mechanism."""
+ tgt2 = self.norm1(tgt)
+ q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
+ tgt2 = self.self_attn(q, k, v=tgt2)
+ tgt = tgt + self.dropout1(tgt2)
+ return tgt
+
+ def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
+ """Performs cross-attention between target and memory tensors using RoPEAttention mechanism."""
+ kwds = {}
+ if num_k_exclude_rope > 0:
+ assert isinstance(self.cross_attn_image, RoPEAttention)
+ kwds = {"num_k_exclude_rope": num_k_exclude_rope}
+
+ # Cross-Attention
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.cross_attn_image(
+ q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
+ k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
+ v=memory,
+ **kwds,
+ )
+ tgt = tgt + self.dropout2(tgt2)
+ return tgt
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ num_k_exclude_rope: int = 0,
+ ) -> torch.Tensor:
+ """Performs self-attention, cross-attention, and MLP operations on input tensors for memory-based attention."""
+ tgt = self._forward_sa(tgt, query_pos)
+ tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
+ # MLP
+ tgt2 = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt
+
+
+class MemoryAttention(nn.Module):
+ """Memory attention module for processing sequential data with self and cross-attention mechanisms."""
+
+ def __init__(
+ self,
+ d_model: int,
+ pos_enc_at_input: bool,
+ layer: nn.Module,
+ num_layers: int,
+ batch_first: bool = True, # Do layers expect batch first input?
+ ):
+ """Initializes MemoryAttention module with layers and normalization for attention processing."""
+ super().__init__()
+ self.d_model = d_model
+ self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
+ self.num_layers = num_layers
+ self.norm = nn.LayerNorm(d_model)
+ self.pos_enc_at_input = pos_enc_at_input
+ self.batch_first = batch_first
+
+ def forward(
+ self,
+ curr: torch.Tensor, # self-attention inputs
+ memory: torch.Tensor, # cross-attention inputs
+ curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
+ memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
+ num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
+ ):
+ """Applies self-attention and cross-attention to input tensors, processing through multiple layers."""
+ if isinstance(curr, list):
+ assert isinstance(curr_pos, list)
+ assert len(curr) == len(curr_pos) == 1
+ curr, curr_pos = (
+ curr[0],
+ curr_pos[0],
+ )
+
+ assert curr.shape[1] == memory.shape[1], "Batch size must be the same for curr and memory"
+
+ output = curr
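+ # Optionally fold the positional encoding into the input, down-weighted by 0.1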
+ if self.pos_enc_at_input and curr_pos is not None:
+ output = output + 0.1 * curr_pos
+
+ if self.batch_first:
+ # Convert to batch first
+ output = output.transpose(0, 1)
+ curr_pos = curr_pos.transpose(0, 1)
+ memory = memory.transpose(0, 1)
+ memory_pos = memory_pos.transpose(0, 1)
+
+ for layer in self.layers:
+ kwds = {}
+ if isinstance(layer.cross_attn_image, RoPEAttention):
+ kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
+
+ output = layer(
+ tgt=output,
+ memory=memory,
+ pos=memory_pos,
+ query_pos=curr_pos,
+ **kwds,
+ )
+ normed_output = self.norm(output)
+
+ if self.batch_first:
+ # Convert back to seq first
+ normed_output = normed_output.transpose(0, 1)
+ curr_pos = curr_pos.transpose(0, 1)
+
+ return normed_output
diff --git a/ultralytics/models/sam2/modules/sam2.py b/ultralytics/models/sam2/modules/sam2.py
new file mode 100644
index 0000000000..363241a19d
--- /dev/null
+++ b/ultralytics/models/sam2/modules/sam2.py
@@ -0,0 +1,804 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import torch
+import torch.distributed
+import torch.nn.functional as F
+from torch.nn.init import trunc_normal_
+
+from ultralytics.models.sam.modules.encoders import PromptEncoder
+from ultralytics.nn.modules import MLP
+
+from .decoders import MaskDecoder
+from .sam2_blocks import TwoWayTransformer
+from .utils import get_1d_sine_pe, select_closest_cond_frames
+
+# a large negative value as a placeholder score for missing objects
+NO_OBJ_SCORE = -1024.0
+
+
+class SAM2Model(torch.nn.Module):
+ """SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities."""
+
+ mask_threshold: float = 0.0
+
+ def __init__(
+ self,
+ image_encoder,
+ memory_attention,
+ memory_encoder,
+ num_maskmem=7, # default 1 input frame + 6 previous frames
+ image_size=512,
+ backbone_stride=16, # stride of the image backbone output
+ sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob
+ sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob
+ # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks
+ binarize_mask_from_pts_for_mem_enc=False,
+ use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder
+ # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit,
+ # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model
+ # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM.
+ max_cond_frames_in_attn=-1,
+ # on the first frame, whether to directly add the no-memory embedding to the image feature
+ # (instead of using the transformer encoder)
+ directly_add_no_mem_embed=False,
+ # whether to use high-resolution feature maps in the SAM mask decoder
+ use_high_res_features_in_sam=False,
+ # whether to output multiple (3) masks for the first click on initial conditioning frames
+ multimask_output_in_sam=False,
+ # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`;
+ # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points)
+ multimask_min_pt_num=1,
+ multimask_max_pt_num=1,
+ # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`)
+ multimask_output_for_tracking=False,
+ # Whether to use multimask tokens for obj ptr; Only relevant when both
+ # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True
+ use_multimask_token_for_obj_ptr: bool = False,
+ # whether to use sigmoid to restrict IoU predictions to [0, 1]
+ iou_prediction_use_sigmoid=False,
+ # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5).
+ # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
+ # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
+ memory_temporal_stride_for_eval=1,
+ # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
+ # if `add_all_frames_to_correct_as_cond` is False, we restrict the conditioning frame list to the initial conditioning frames only
+ add_all_frames_to_correct_as_cond=False,
+ # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
+ non_overlap_masks_for_mem_enc=False,
+ # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder=False,
+ # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`)
+ max_obj_ptrs_in_encoder=16,
+ # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`)
+ add_tpos_enc_to_obj_ptrs=True,
+ # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
+ # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
+ proj_tpos_enc_in_obj_ptrs=False,
+ # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
+ # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking)
+ only_obj_ptrs_in_the_past_for_eval=False,
+ # Whether to predict if there is an object in the frame
+ pred_obj_scores: bool = False,
+ # Whether to use an MLP to predict object scores
+ pred_obj_scores_mlp: bool = False,
+ # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True;
+ # Whether to have a fixed no obj pointer when there is no object present
+ # or to use it as an additive embedding with obj_ptr produced by decoder
+ fixed_no_obj_ptr: bool = False,
+ # Soft no-object, i.e. mix in no_obj_ptr softly,
+ # hoping to make recovery easier after a mistake and to mitigate error accumulation
+ soft_no_obj_ptr: bool = False,
+ use_mlp_for_obj_ptr_proj: bool = False,
+ # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
+ sam_mask_decoder_extra_args=None,
+ compile_image_encoder: bool = False,
+ ):
+ """Initializes SAM2Model model with image encoder, memory attention, and memory encoder components."""
+ super().__init__()
+
+ # Part 1: the image backbone
+ self.image_encoder = image_encoder
+ # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
+ self.use_high_res_features_in_sam = use_high_res_features_in_sam
+ self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
+ self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
+ self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
+ if use_obj_ptrs_in_encoder:
+ # A conv layer to downsample the mask prompt to stride 4 (the same stride as
+ # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
+ # so that it can be fed into the SAM mask decoder to generate a pointer.
+ self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
+ self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
+ if proj_tpos_enc_in_obj_ptrs:
+ assert add_tpos_enc_to_obj_ptrs # these options need to be used together
+ self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
+ self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
+
+ # Part 2: memory attention to condition current frame's visual features
+ # with memories (and obj ptrs) from past frames
+ self.memory_attention = memory_attention
+ self.hidden_dim = memory_attention.d_model
+
+ # Part 3: memory encoder for the previous frame's outputs
+ self.memory_encoder = memory_encoder
+ self.mem_dim = self.hidden_dim
+ if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
+ # if there is compression of memories along channel dim
+ self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
+ self.num_maskmem = num_maskmem # Number of memories accessible
+ # Temporal encoding of the memories
+ self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
+ trunc_normal_(self.maskmem_tpos_enc, std=0.02)
+ # a single token to indicate no memory embedding from previous frames
+ self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+ self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+ trunc_normal_(self.no_mem_embed, std=0.02)
+ trunc_normal_(self.no_mem_pos_enc, std=0.02)
+ self.directly_add_no_mem_embed = directly_add_no_mem_embed
+ # Apply sigmoid to the output raw mask logits (to turn them from
+ # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
+ self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
+ self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
+ self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
+ self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
+ self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
+ # On frames with mask input, whether to directly output the input mask without
+ # using a SAM prompt encoder + mask decoder
+ self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
+ self.multimask_output_in_sam = multimask_output_in_sam
+ self.multimask_min_pt_num = multimask_min_pt_num
+ self.multimask_max_pt_num = multimask_max_pt_num
+ self.multimask_output_for_tracking = multimask_output_for_tracking
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
+ self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
+
+ # Part 4: SAM-style prompt encoder (for both mask and point inputs)
+ # and SAM-style mask decoder for the final mask output
+ self.image_size = image_size
+ self.backbone_stride = backbone_stride
+ self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
+ self.pred_obj_scores = pred_obj_scores
+ self.pred_obj_scores_mlp = pred_obj_scores_mlp
+ self.fixed_no_obj_ptr = fixed_no_obj_ptr
+ self.soft_no_obj_ptr = soft_no_obj_ptr
+ if self.fixed_no_obj_ptr:
+ assert self.pred_obj_scores
+ assert self.use_obj_ptrs_in_encoder
+ if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
+ self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
+ trunc_normal_(self.no_obj_ptr, std=0.02)
+ self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
+
+ self._build_sam_heads()
+ self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
+ self.max_cond_frames_in_attn = max_cond_frames_in_attn
+
+ # Model compilation
+ if compile_image_encoder:
+ # Compile the forward function (not the full module) to allow loading checkpoints.
+ print("Image encoder compilation is enabled. First forward pass will be slow.")
+ self.image_encoder.forward = torch.compile(
+ self.image_encoder.forward,
+ mode="max-autotune",
+ fullgraph=True,
+ dynamic=False,
+ )
+
+ @property
+ def device(self):
+ """Returns the device on which the model's parameters are stored."""
+ return next(self.parameters()).device
+
+ def forward(self, *args, **kwargs):
+ """Processes input frames and prompts to generate object masks and scores in video sequences."""
+ raise NotImplementedError(
+ "Please use the corresponding methods in SAM2VideoPredictor for inference."
+ "See notebooks/video_predictor_example.ipynb for an example."
+ )
+
+ def _build_sam_heads(self):
+ """Builds SAM-style prompt encoder and mask decoder for image segmentation tasks."""
+ self.sam_prompt_embed_dim = self.hidden_dim
+ self.sam_image_embedding_size = self.image_size // self.backbone_stride
+
+ # build PromptEncoder and MaskDecoder from SAM
+ # (their hyperparameters like `mask_in_chans=16` are from SAM code)
+ self.sam_prompt_encoder = PromptEncoder(
+ embed_dim=self.sam_prompt_embed_dim,
+ image_embedding_size=(
+ self.sam_image_embedding_size,
+ self.sam_image_embedding_size,
+ ),
+ input_image_size=(self.image_size, self.image_size),
+ mask_in_chans=16,
+ )
+ self.sam_mask_decoder = MaskDecoder(
+ num_multimask_outputs=3,
+ transformer=TwoWayTransformer(
+ depth=2,
+ embedding_dim=self.sam_prompt_embed_dim,
+ mlp_dim=2048,
+ num_heads=8,
+ ),
+ transformer_dim=self.sam_prompt_embed_dim,
+ iou_head_depth=3,
+ iou_head_hidden_dim=256,
+ use_high_res_features=self.use_high_res_features_in_sam,
+ iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
+ pred_obj_scores=self.pred_obj_scores,
+ pred_obj_scores_mlp=self.pred_obj_scores_mlp,
+ use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
+ **(self.sam_mask_decoder_extra_args or {}),
+ )
+ if self.use_obj_ptrs_in_encoder:
+ # a linear projection on SAM output tokens to turn them into object pointers
+ self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
+ if self.use_mlp_for_obj_ptr_proj:
+ self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
+ else:
+ self.obj_ptr_proj = torch.nn.Identity()
+ if self.proj_tpos_enc_in_obj_ptrs:
+ # a linear projection on temporal positional encoding in object pointers to
+ # avoid potential interference with spatial positional encoding
+ self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
+ else:
+ self.obj_ptr_tpos_proj = torch.nn.Identity()
+
+ def _forward_sam_heads(
+ self,
+ backbone_features,
+ point_inputs=None,
+ mask_inputs=None,
+ high_res_features=None,
+ multimask_output=False,
+ ):
+ """
+ Forward SAM prompt encoders and mask heads.
+
+ Args:
+ backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
+ point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts.
+ 'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
+ pixel-unit coordinates in (x, y) format for P input points.
+ 'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
+ 0 means negative clicks, and -1 means padding.
+ mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
+ same spatial size as the image.
+ high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes
+ (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps
+ for SAM decoder.
+ multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
+ output only 1 mask and its IoU estimate.
+
+ Returns:
+ (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
+ low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
+ high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
+ ious: Tensor of shape (B, M) with estimated IoU for each output mask.
+ low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask.
+ high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask.
+ obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
+ object_score_logits: Tensor of shape (B,) with object score logits.
+
+ Where M is 3 if multimask_output=True, and 1 if multimask_output=False.
+
+ Examples:
+ >>> backbone_features = torch.rand(1, 256, 32, 32)
+ >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
+ >>> mask_inputs = torch.rand(1, 1, 512, 512)
+ >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
+ >>> low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits = results
+ """
+ B = backbone_features.size(0)
+ device = backbone_features.device
+ assert backbone_features.size(1) == self.sam_prompt_embed_dim
+ assert backbone_features.size(2) == self.sam_image_embedding_size
+ assert backbone_features.size(3) == self.sam_image_embedding_size
+
+ # a) Handle point prompts
+ if point_inputs is not None:
+ sam_point_coords = point_inputs["point_coords"]
+ sam_point_labels = point_inputs["point_labels"]
+ assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
+ else:
+ # If no points are provided, pad with an empty point (with label -1)
+ sam_point_coords = torch.zeros(B, 1, 2, device=device)
+ sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
+
+ # b) Handle mask prompts
+ if mask_inputs is not None:
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
+ # and feed it as a dense mask prompt into the SAM mask encoder
+ assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
+ if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
+ sam_mask_prompt = F.interpolate(
+ mask_inputs.float(),
+ size=self.sam_prompt_encoder.mask_input_size,
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ )
+ else:
+ sam_mask_prompt = mask_inputs
+ else:
+ # Otherwise, simply feed None (and SAM's prompt encoder will add
+ # a learned `no_mask_embed` to indicate no mask input in this case).
+ sam_mask_prompt = None
+
+ sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
+ points=(sam_point_coords, sam_point_labels),
+ boxes=None,
+ masks=sam_mask_prompt,
+ )
+ (
+ low_res_multimasks,
+ ious,
+ sam_output_tokens,
+ object_score_logits,
+ ) = self.sam_mask_decoder(
+ image_embeddings=backbone_features,
+ image_pe=self.sam_prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ repeat_image=False, # the image is already batched
+ high_res_features=high_res_features,
+ )
+ if self.pred_obj_scores:
+ is_obj_appearing = object_score_logits > 0
+
+ # Mask used for spatial memories is always a *hard* choice between obj and no obj,
+ # consistent with the actual mask prediction
+ low_res_multimasks = torch.where(
+ is_obj_appearing[:, None, None],
+ low_res_multimasks,
+ NO_OBJ_SCORE,
+ )
+
+ # convert masks from possibly bfloat16 (or float16) to float32
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
+ low_res_multimasks = low_res_multimasks.float()
+ high_res_multimasks = F.interpolate(
+ low_res_multimasks,
+ size=(self.image_size, self.image_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+
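+ # Default to the first output token; with multimask output, the best mask's token is selected below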
+ sam_output_token = sam_output_tokens[:, 0]
+ if multimask_output:
+ # take the best mask prediction (with the highest IoU estimation)
+ best_iou_inds = torch.argmax(ious, dim=-1)
+ batch_inds = torch.arange(B, device=device)
+ low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+ high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+ if sam_output_tokens.size(1) > 1:
+ sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
+ else:
+ low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
+
+ # Extract object pointer from the SAM output token (with occlusion handling)
+ obj_ptr = self.obj_ptr_proj(sam_output_token)
+ if self.pred_obj_scores:
+ # Allow *soft* no obj ptr, unlike for masks
+ if self.soft_no_obj_ptr:
+ # Only hard possible with gt
+ assert not self.teacher_force_obj_scores_for_mem
+ lambda_is_obj_appearing = object_score_logits.sigmoid()
+ else:
+ lambda_is_obj_appearing = is_obj_appearing.float()
+
+ if self.fixed_no_obj_ptr:
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+
+ return (
+ low_res_multimasks,
+ high_res_multimasks,
+ ious,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ object_score_logits,
+ )
+
+ def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
+ """Processes mask inputs to generate output mask logits and object pointers without using SAM."""
+ # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
+ out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
+ mask_inputs_float = mask_inputs.float()
+ high_res_masks = mask_inputs_float * out_scale + out_bias
+ low_res_masks = F.interpolate(
+ high_res_masks,
+ size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ )
+ # a dummy IoU prediction of all 1's under mask input
+ ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
+ if not self.use_obj_ptrs_in_encoder:
+ # all zeros as a dummy object pointer (of shape [B, C])
+ obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
+ else:
+ # produce an object pointer using the SAM decoder from the mask input
+ _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
+ backbone_features=backbone_features,
+ mask_inputs=self.mask_downsample(mask_inputs_float),
+ high_res_features=high_res_features,
+ )
+ # In this method, we treat mask_input as the output, e.g. using it directly to create the spatial memory;
+ # below, we follow the same principle and use mask_input to decide whether the object appears, instead of
+ # relying on the object_scores from the SAM decoder.
+ is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
+ is_obj_appearing = is_obj_appearing[..., None]
+ lambda_is_obj_appearing = is_obj_appearing.float()
+ object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
+ if self.pred_obj_scores:
+ if self.fixed_no_obj_ptr:
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+
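+ # Return the same 7-tuple layout as _forward_sam_heads (the single mask doubles as the multimask output)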
+ return (
+ low_res_masks,
+ high_res_masks,
+ ious,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ object_score_logits,
+ )
+
+ def forward_image(self, img_batch: torch.Tensor):
+ """Process image batch through encoder to extract multi-level features for SAM model."""
+ backbone_out = self.image_encoder(img_batch)
+ if self.use_high_res_features_in_sam:
+ # precompute projected level 0 and level 1 features in SAM decoder
+ # to avoid running it again on every SAM click
+ backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
+ backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
+ return backbone_out
+
+ def _prepare_backbone_features(self, backbone_out):
+ """Prepare and flatten visual features from the image backbone output."""
+ backbone_out = backbone_out.copy()
+ assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
+ assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
+
+ feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
+ vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
+
+ feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
+ # flatten NxCxHxW to HWxNxC
+ vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
+ vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
+
+ return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
+
+ def _prepare_memory_conditioned_features(
+ self,
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ output_dict,
+ num_frames,
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
+ ):
+ """Prepares memory-conditioned features by fusing current frame's visual features with previous memories."""
+ B = current_vision_feats[-1].size(1) # batch size on this frame
+ C = self.hidden_dim
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
+ device = current_vision_feats[-1].device
+ # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
+ # In this case, we skip the fusion with any memory.
+ if self.num_maskmem == 0: # Disable memory and skip fusion
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+ return pix_feat
+
+ num_obj_ptr_tokens = 0
+ # Step 1: condition the visual features of the current frame on previous memories
+ if not is_init_cond_frame:
+ # Retrieve the memories encoded with the maskmem backbone
+ to_cat_memory, to_cat_memory_pos_embed = [], []
+ # Add the conditioning frames' outputs first (all cond frames have t_pos=0
+ # when getting the temporal positional embedding below)
+ assert len(output_dict["cond_frame_outputs"]) > 0
+ # Select a maximum number of temporally closest cond frames for cross attention
+ cond_outputs = output_dict["cond_frame_outputs"]
+ selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
+ frame_idx, cond_outputs, self.max_cond_frames_in_attn
+ )
+ t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
+ # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
+ # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
+ # We also allow taking the memory frame non-consecutively (with r>1), in which case
+ # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
+ r = self.memory_temporal_stride_for_eval
+ for t_pos in range(1, self.num_maskmem):
+ t_rel = self.num_maskmem - t_pos # how many frames before current frame
+ if t_rel == 1:
+ # for t_rel == 1, we take the last frame (regardless of r)
+ if not track_in_reverse:
+ # the frame immediately before this frame (i.e. frame_idx - 1)
+ prev_frame_idx = frame_idx - t_rel
+ else:
+ # the frame immediately after this frame (i.e. frame_idx + 1)
+ prev_frame_idx = frame_idx + t_rel
+ else:
+ # for t_rel >= 2, we take the memory frame from every r-th frames
+ if not track_in_reverse:
+ # first find the nearest frame among every r-th frames before this frame
+ # for r=1, this would be (frame_idx - 2)
+ prev_frame_idx = ((frame_idx - 2) // r) * r
+ # then seek further among every r-th frames
+ prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
+ else:
+ # first find the nearest frame among every r-th frames after this frame
+ # for r=1, this would be (frame_idx + 2)
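+ # -(-x // r) is ceil(x / r), so this rounds (frame_idx + 2) up to a multiple of r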
+ prev_frame_idx = -(-(frame_idx + 2) // r) * r
+ # then seek further among every r-th frames
+ prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
+ out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
+ if out is None:
+ # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
+ # frames, we still attend to it as if it's a non-conditioning frame.
+ out = unselected_cond_outputs.get(prev_frame_idx, None)
+ t_pos_and_prevs.append((t_pos, out))
+
+ for t_pos, prev in t_pos_and_prevs:
+ if prev is None:
+ continue # skip padding frames
+ # "maskmem_features" might have been offloaded to CPU in demo use cases,
+ # so we load it back to GPU (it's a no-op if it's already on GPU).
+ feats = prev["maskmem_features"].cuda(non_blocking=True)
+ to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
+ # Spatial positional encoding (it might have been offloaded to CPU in eval)
+ maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
+ maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
+ # Temporal positional encoding
+ maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
+ to_cat_memory_pos_embed.append(maskmem_enc)
+
+ # Construct the list of past object pointers
+ if self.use_obj_ptrs_in_encoder:
+ max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
+ # First add those object pointers from selected conditioning frames
+ # (optionally, only include object pointers in the past during evaluation)
+ if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
+ ptr_cond_outputs = {
+ t: out
+ for t, out in selected_cond_outputs.items()
+ if (t >= frame_idx if track_in_reverse else t <= frame_idx)
+ }
+ else:
+ ptr_cond_outputs = selected_cond_outputs
+ pos_and_ptrs = [
+ # Temporal pos encoding contains how far away each pointer is from current frame
+ (abs(frame_idx - t), out["obj_ptr"])
+ for t, out in ptr_cond_outputs.items()
+ ]
+ # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
+ for t_diff in range(1, max_obj_ptrs_in_encoder):
+ t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
+ if t < 0 or (num_frames is not None and t >= num_frames):
+ break
+ out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
+ if out is not None:
+ pos_and_ptrs.append((t_diff, out["obj_ptr"]))
+ # If we have at least one object pointer, add them to the cross-attention memory
+ if len(pos_and_ptrs) > 0:
+ pos_list, ptrs_list = zip(*pos_and_ptrs)
+ # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
+ obj_ptrs = torch.stack(ptrs_list, dim=0)
+ # a temporal positional embedding based on how far each object pointer is from
+ # the current frame (sine embedding normalized by the max pointer num).
+ if self.add_tpos_enc_to_obj_ptrs:
+ t_diff_max = max_obj_ptrs_in_encoder - 1
+ tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
+ obj_pos = torch.tensor(pos_list, device=device)
+ obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
+ obj_pos = self.obj_ptr_tpos_proj(obj_pos)
+ obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
+ else:
+ obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
+ if self.mem_dim < C:
+ # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
+ obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
+ obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
+ obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
+ to_cat_memory.append(obj_ptrs)
+ to_cat_memory_pos_embed.append(obj_pos)
+ num_obj_ptr_tokens = obj_ptrs.shape[0]
+ else:
+ num_obj_ptr_tokens = 0
+ else:
+ # for initial conditioning frames, encode them without using any previous memory
+ if self.directly_add_no_mem_embed:
+ # directly add no-mem embedding (instead of using the transformer encoder)
+ pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
+ return pix_feat_with_mem
+
+ # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)
+ to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
+ to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
+
+ # Step 2: Concatenate the memories and forward through the transformer encoder
+ memory = torch.cat(to_cat_memory, dim=0)
+ memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
+
+ pix_feat_with_mem = self.memory_attention(
+ curr=current_vision_feats,
+ curr_pos=current_vision_pos_embeds,
+ memory=memory,
+ memory_pos=memory_pos_embed,
+ num_obj_ptr_tokens=num_obj_ptr_tokens,
+ )
+ # reshape the output (HW)BC => BCHW
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
+ return pix_feat_with_mem
+
+ def _encode_new_memory(
+ self,
+ current_vision_feats,
+ feat_sizes,
+ pred_masks_high_res,
+ is_mask_from_pts,
+ ):
+ """Encodes the current frame's features and predicted masks into a new memory representation."""
+ B = current_vision_feats[-1].size(1) # batch size on this frame
+ C = self.hidden_dim
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
+ # top-level feature, (HW)BC => BCHW
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+ if self.non_overlap_masks_for_mem_enc and not self.training:
+ # optionally, apply non-overlapping constraints to the masks (it's applied
+ # in the batch dimension and should only be used during eval, where all
+ # the objects come from the same video under batch size 1).
+ pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
+ # scale the raw mask logits with a temperature before applying sigmoid
+ binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
+ if binarize and not self.training:
+ mask_for_mem = (pred_masks_high_res > 0).float()
+ else:
+ # apply sigmoid on the raw mask logits to turn them into range (0, 1)
+ mask_for_mem = torch.sigmoid(pred_masks_high_res)
+ # apply scale and bias terms to the sigmoid probabilities
+ if self.sigmoid_scale_for_mem_enc != 1.0:
+ mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
+ if self.sigmoid_bias_for_mem_enc != 0.0:
+ mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
+ maskmem_out = self.memory_encoder(
+ pix_feat,
+ mask_for_mem,
+ skip_mask_sigmoid=True, # sigmoid already applied
+ )
+ maskmem_features = maskmem_out["vision_features"]
+ maskmem_pos_enc = maskmem_out["vision_pos_enc"]
+
+ return maskmem_features, maskmem_pos_enc
+
+ def track_step(
+ self,
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ point_inputs,
+ mask_inputs,
+ output_dict,
+ num_frames,
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
+ # Whether to run the memory encoder on the predicted masks. Sometimes we might want
+ # to skip the memory encoder with `run_mem_encoder=False`. For example,
+ # in demo we might call `track_step` multiple times for each user click,
+ # and only encode the memory when the user finalizes their clicks. And in ablation
+ # settings like SAM training on static images, we don't need the memory encoder.
+ run_mem_encoder=True,
+ # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
+ prev_sam_mask_logits=None,
+ ):
+ """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
+ current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
+ # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
+ if len(current_vision_feats) > 1:
+ high_res_features = [
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
+ for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
+ ]
+ else:
+ high_res_features = None
+ if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
+ # When use_mask_input_as_output_without_sam=True, we directly output the mask input
+ # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0)
+ pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
+ sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
+ else:
+ # fuse the visual features with previous memory features from the memory bank
+ pix_feat_with_mem = self._prepare_memory_conditioned_features(
+ frame_idx=frame_idx,
+ is_init_cond_frame=is_init_cond_frame,
+ current_vision_feats=current_vision_feats[-1:],
+ current_vision_pos_embeds=current_vision_pos_embeds[-1:],
+ feat_sizes=feat_sizes[-1:],
+ output_dict=output_dict,
+ num_frames=num_frames,
+ track_in_reverse=track_in_reverse,
+ )
+ # apply SAM-style segmentation head
+ # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
+ # e.g. in demo where such logits come from earlier interaction instead of correction sampling
+ # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
+ if prev_sam_mask_logits is not None:
+ assert point_inputs is not None and mask_inputs is None
+ mask_inputs = prev_sam_mask_logits
+ multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
+ sam_outputs = self._forward_sam_heads(
+ backbone_features=pix_feat_with_mem,
+ point_inputs=point_inputs,
+ mask_inputs=mask_inputs,
+ high_res_features=high_res_features,
+ multimask_output=multimask_output,
+ )
+ (
+ _,
+ _,
+ _,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ _,
+ ) = sam_outputs
+
+ current_out["pred_masks"] = low_res_masks
+ current_out["pred_masks_high_res"] = high_res_masks
+ current_out["obj_ptr"] = obj_ptr
+
+ # Finally run the memory encoder on the predicted mask to encode
+ # it into a new memory feature (that can be used in future frames)
+ if run_mem_encoder and self.num_maskmem > 0:
+ high_res_masks_for_mem_enc = high_res_masks
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+ current_vision_feats=current_vision_feats,
+ feat_sizes=feat_sizes,
+ pred_masks_high_res=high_res_masks_for_mem_enc,
+ is_mask_from_pts=(point_inputs is not None),
+ )
+ current_out["maskmem_features"] = maskmem_features
+ current_out["maskmem_pos_enc"] = maskmem_pos_enc
+ else:
+ current_out["maskmem_features"] = None
+ current_out["maskmem_pos_enc"] = None
+
+ return current_out
+
+ def _use_multimask(self, is_init_cond_frame, point_inputs):
+ """Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
+ num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
+ multimask_output = (
+ self.multimask_output_in_sam
+ and (is_init_cond_frame or self.multimask_output_for_tracking)
+ and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
+ )
+ return multimask_output
+
+ def _apply_non_overlapping_constraints(self, pred_masks):
+ """Applies non-overlapping constraints to object masks, keeping highest scoring object at each location."""
+ batch_size = pred_masks.size(0)
+ if batch_size == 1:
+ return pred_masks
+
+ device = pred_masks.device
+ # "max_obj_inds": object index of the object with the highest score at each location
+ max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
+ # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
+ batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
+ keep = max_obj_inds == batch_obj_inds
+ # suppress overlapping regions' scores below -10.0 so that the foreground regions
+ # don't overlap (here sigmoid(-10.0)=4.5398e-05)
+ pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
+ return pred_masks
diff --git a/ultralytics/models/sam2/modules/sam2_blocks.py b/ultralytics/models/sam2/modules/sam2_blocks.py
new file mode 100644
index 0000000000..67dc587e95
--- /dev/null
+++ b/ultralytics/models/sam2/modules/sam2_blocks.py
@@ -0,0 +1,715 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import copy
+import math
+from functools import partial
+from typing import Optional, Tuple, Type, Union
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from ultralytics.models.sam.modules.transformer import (
+ Attention,
+ TwoWayAttentionBlock as SAMTwoWayAttentionBlock,
+ TwoWayTransformer as SAMTwoWayTransformer,
+)
+from ultralytics.nn.modules import MLP, LayerNorm2d
+
+from .utils import apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition
+
+
+class DropPath(nn.Module):
+ """Implements stochastic depth regularization for neural networks during training."""
+
+ def __init__(self, drop_prob=0.0, scale_by_keep=True):
+ """Initialize DropPath module with specified drop probability and scaling option."""
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ """Applies stochastic depth to input tensor during training, with optional scaling."""
+ if self.drop_prob == 0.0 or not self.training:
+ return x
+ keep_prob = 1 - self.drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and self.scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+class MaskDownSampler(nn.Module):
+ """Downsamples and embeds masks using convolutional layers and layer normalization for efficient processing."""
+
+ def __init__(
+ self,
+ embed_dim=256,
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ total_stride=16,
+ activation=nn.GELU,
+ ):
+ """Initializes a mask downsampler module for progressive downsampling and channel expansion."""
+ super().__init__()
+ num_layers = int(math.log2(total_stride) // math.log2(stride))
+ assert stride**num_layers == total_stride
+ self.encoder = nn.Sequential()
+ mask_in_chans, mask_out_chans = 1, 1
+ for _ in range(num_layers):
+ mask_out_chans = mask_in_chans * (stride**2)
+ self.encoder.append(
+ nn.Conv2d(
+ mask_in_chans,
+ mask_out_chans,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ )
+ )
+ self.encoder.append(LayerNorm2d(mask_out_chans))
+ self.encoder.append(activation())
+ mask_in_chans = mask_out_chans
+
+ self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
+
+ def forward(self, x):
+ """Downsamples and encodes input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
+ return self.encoder(x)
+
+
+# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
+class CXBlock(nn.Module):
+ """
+ ConvNeXt Block for efficient feature extraction in convolutional neural networks.
+
+ This block implements a modified version of the ConvNeXt architecture, offering two equivalent
+ implementations for improved performance and flexibility.
+
+ Attributes:
+ dwconv (nn.Conv2d): Depthwise convolution layer.
+ norm (LayerNorm2d): Layer normalization applied to channels.
+ pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.
+ act (nn.GELU): GELU activation function.
+ pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.
+ gamma (nn.Parameter | None): Learnable scale parameter for layer scaling.
+ drop_path (nn.Module): DropPath layer for stochastic depth regularization.
+
+ Methods:
+ forward: Processes the input tensor through the ConvNeXt block.
+
+ Examples:
+ >>> import torch
+ >>> x = torch.randn(1, 64, 56, 56)
+ >>> block = CXBlock(dim=64, kernel_size=7, padding=3)
+ >>> output = block(x)
+ >>> print(output.shape)
+ torch.Size([1, 64, 56, 56])
+ """
+
+ def __init__(
+ self,
+ dim,
+ kernel_size=7,
+ padding=3,
+ drop_path=0.0,
+ layer_scale_init_value=1e-6,
+ use_dwconv=True,
+ ):
+ """
+ Initialize a ConvNeXt Block.
+
+ This block implements a ConvNeXt architecture with optional depthwise convolution, layer normalization,
+ pointwise convolutions, and GELU activation.
+
+ Args:
+ dim (int): Number of input channels.
+ kernel_size (int): Size of the convolutional kernel. Default is 7.
+ padding (int): Padding size for the convolution. Default is 3.
+ drop_path (float): Stochastic depth rate. Default is 0.0.
+ layer_scale_init_value (float): Initial value for Layer Scale. Default is 1e-6.
+ use_dwconv (bool): Whether to use depthwise convolution. Default is True.
+
+ Attributes:
+ dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer.
+ norm (LayerNorm2d): Layer normalization applied to the output of dwconv.
+ pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.
+ act (nn.GELU): GELU activation function.
+ pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.
+ gamma (nn.Parameter | None): Learnable scale parameter for the residual path.
+
+ Examples:
+ >>> block = CXBlock(dim=64, kernel_size=7, padding=3)
+ >>> x = torch.randn(1, 64, 32, 32)
+ >>> output = block(x)
+ >>> print(output.shape)
+ torch.Size([1, 64, 32, 32])
+ """
+ super().__init__()
+ self.dwconv = nn.Conv2d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=padding,
+ groups=dim if use_dwconv else 1,
+ ) # depthwise conv
+ self.norm = LayerNorm2d(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+ if layer_scale_init_value > 0
+ else None
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, x):
+ """Applies ConvNeXt block operations to input tensor, including convolutions and residual connection."""
+ input = x
+ x = self.dwconv(x)
+ x = self.norm(x)
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ x = input + self.drop_path(x)
+ return x
+
+
+class Fuser(nn.Module):
+ """
+ A module for fusing features through multiple layers of a neural network.
+
+ This class applies a series of identical layers to an input tensor, optionally projecting the input first.
+
+ Attributes:
+ proj (nn.Module): An optional input projection layer. Identity if no projection is needed.
+ layers (nn.ModuleList): A list of identical layers to be applied sequentially.
+
+ Methods:
+ forward: Applies the fuser to an input tensor.
+
+ Examples:
+ >>> layer = CXBlock(dim=256)
+ >>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True)
+ >>> x = torch.randn(1, 256, 32, 32)
+ >>> output = fuser(x)
+ >>> print(output.shape)
+ torch.Size([1, 256, 32, 32])
+ """
+
+ def __init__(self, layer, num_layers, dim=None, input_projection=False):
+ """
+ Initializes the Fuser module.
+
+ This module creates a sequence of identical layers and optionally applies an input projection.
+
+ Args:
+ layer (nn.Module): The layer to be replicated in the fuser.
+ num_layers (int): The number of times to replicate the layer.
+ dim (int | None): The dimension for input projection, if used.
+ input_projection (bool): Whether to use input projection.
+
+ Attributes:
+ proj (nn.Module): The input projection layer, or nn.Identity if not used.
+ layers (nn.ModuleList): A list of replicated layers.
+
+ Examples:
+ >>> layer = CXBlock(dim=64)
+ >>> fuser = Fuser(layer, num_layers=3, dim=64, input_projection=True)
+ >>> input_tensor = torch.randn(1, 64, 32, 32)
+ >>> output = fuser(input_tensor)
+ """
+ super().__init__()
+ self.proj = nn.Identity()
+ self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
+
+ if input_projection:
+ assert dim is not None
+ self.proj = nn.Conv2d(dim, dim, kernel_size=1)
+
+ def forward(self, x):
+ """Applies a series of layers to the input tensor, optionally projecting it first."""
+ x = self.proj(x)
+ for layer in self.layers:
+ x = layer(x)
+ return x
+
+
+class TwoWayAttentionBlock(SAMTwoWayAttentionBlock):
+ """
+ A two-way attention block for performing self-attention and cross-attention in both directions.
+
+ This block extends the SAMTwoWayAttentionBlock and consists of four main components: self-attention on
+ sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and
+ cross-attention from dense to sparse inputs.
+
+ Attributes:
+ self_attn (Attention): Self-attention layer for queries.
+ norm1 (nn.LayerNorm): Layer normalization after the first attention block.
+ cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
+ norm2 (nn.LayerNorm): Layer normalization after the second attention block.
+ mlp (MLP): MLP block for transforming query embeddings.
+ norm3 (nn.LayerNorm): Layer normalization after the MLP block.
+ norm4 (nn.LayerNorm): Layer normalization after the third attention block.
+ cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
+ skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer.
+
+ Methods:
+ forward: Processes input through the attention blocks and MLP.
+
+ Examples:
+ >>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
+ >>> sparse_inputs = torch.randn(1, 100, 256)
+ >>> dense_inputs = torch.randn(1, 64 * 64, 256)
+ >>> sparse_pe, dense_pe = torch.zeros_like(sparse_inputs), torch.zeros_like(dense_inputs)
+ >>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs, sparse_pe, dense_pe)
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int = 2048,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ skip_first_layer_pe: bool = False,
+ ) -> None:
+ """
+ Initializes a TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
+
+ This block consists of four main layers: self-attention on sparse inputs, cross-attention of sparse inputs
+ to dense inputs, an MLP block on sparse inputs, and cross-attention of dense inputs to sparse inputs.
+
+ Args:
+ embedding_dim (int): The channel dimension of the embeddings.
+ num_heads (int): The number of heads in the attention layers.
+ mlp_dim (int): The hidden dimension of the MLP block.
+ activation (Type[nn.Module]): The activation function of the MLP block.
+ attention_downsample_rate (int): The downsample rate for attention computations.
+ skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
+
+ Attributes:
+ self_attn (Attention): The self-attention layer for the queries.
+ norm1 (nn.LayerNorm): Layer normalization following the first attention block.
+ cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
+ norm2 (nn.LayerNorm): Layer normalization following the second attention block.
+ mlp (MLP): MLP block that transforms the query embeddings.
+ norm3 (nn.LayerNorm): Layer normalization following the MLP block.
+ norm4 (nn.LayerNorm): Layer normalization following the third attention block.
+ cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
+ skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
+
+ Examples:
+ >>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048)
+ >>> sparse_inputs = torch.randn(1, 100, 256)
+ >>> dense_inputs = torch.randn(1, 32 * 32, 256)
+ >>> sparse_pe, dense_pe = torch.zeros_like(sparse_inputs), torch.zeros_like(dense_inputs)
+ >>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs, sparse_pe, dense_pe)
+ """
+ super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe)
+ self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation)
+
+
+class TwoWayTransformer(SAMTwoWayTransformer):
+ """
+ A Two-Way Transformer module for simultaneous attention to image and query points.
+
+ This class implements a specialized transformer decoder that attends to an input image using queries with
+ supplied positional embeddings. It is particularly useful for tasks like object detection, image
+ segmentation, and point cloud processing.
+
+ Attributes:
+ depth (int): Number of layers in the transformer.
+ embedding_dim (int): Channel dimension for input embeddings.
+ num_heads (int): Number of heads for multihead attention.
+ mlp_dim (int): Internal channel dimension for the MLP block.
+ layers (nn.ModuleList): List of TwoWayAttentionBlock layers comprising the transformer.
+ final_attn_token_to_image (Attention): Final attention layer from queries to image.
+ norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
+
+ Methods:
+ forward: Processes input image embeddings and query embeddings through the transformer.
+
+ Examples:
+ >>> transformer = TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
+ >>> image_embedding = torch.randn(1, 256, 64, 64)
+ >>> image_pe = torch.randn(1, 256, 64, 64)
+ >>> point_embedding = torch.randn(1, 100, 256)
+ >>> queries, keys = transformer(image_embedding, image_pe, point_embedding)
+ """
+
+ def __init__(
+ self,
+ depth: int,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ ) -> None:
+ """
+ Initializes a TwoWayTransformer instance.
+
+ This transformer decoder attends to an input image using queries with supplied positional embeddings.
+ It is designed for tasks like object detection, image segmentation, and point cloud processing.
+
+ Args:
+ depth (int): Number of layers in the transformer.
+ embedding_dim (int): Channel dimension for the input embeddings.
+ num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
+ mlp_dim (int): Channel dimension internal to the MLP block.
+ activation (Type[nn.Module]): Activation function to use in the MLP block.
+ attention_downsample_rate (int): Downsampling rate for attention computations.
+
+ Attributes:
+ depth (int): Number of layers in the transformer.
+ embedding_dim (int): Channel dimension for the input embeddings.
+ num_heads (int): Number of heads for multihead attention.
+ mlp_dim (int): Internal channel dimension for the MLP block.
+ layers (nn.ModuleList): List of TwoWayAttentionBlock layers comprising the transformer.
+ final_attn_token_to_image (Attention): Final attention layer from queries to image.
+ norm_final_attn (nn.LayerNorm): Layer normalization applied to the final queries.
+
+ Examples:
+ >>> transformer = TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
+ >>> transformer
+ TwoWayTransformer(
+ (layers): ModuleList(
+ (0-4): 5 x TwoWayAttentionBlock(...)
+ )
+ (final_attn_token_to_image): Attention(...)
+ (norm_final_attn): LayerNorm(...)
+ )
+ """
+ super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate)
+ self.layers = nn.ModuleList()
+ for i in range(depth):
+ self.layers.append(
+ TwoWayAttentionBlock(
+ embedding_dim=embedding_dim,
+ num_heads=num_heads,
+ mlp_dim=mlp_dim,
+ activation=activation,
+ attention_downsample_rate=attention_downsample_rate,
+ skip_first_layer_pe=(i == 0),
+ )
+ )
+
+
+class RoPEAttention(Attention):
+ """Implements rotary position encoding for attention mechanisms in transformer architectures."""
+
+ def __init__(
+ self,
+ *args,
+ rope_theta=10000.0,
+ # whether to repeat q rope to match k length
+ # this is needed for cross-attention to memories
+ rope_k_repeat=False,
+ feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
+ **kwargs,
+ ):
+ """Initializes RoPEAttention with rotary position encoding for attention mechanisms."""
+ super().__init__(*args, **kwargs)
+
+ self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)
+ freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
+ self.freqs_cis = freqs_cis
+ self.rope_k_repeat = rope_k_repeat
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
+ """Applies rotary position encoding and computes attention between query, key, and value tensors."""
+ q = self.q_proj(q)
+ k = self.k_proj(k)
+ v = self.v_proj(v)
+
+ # Separate into heads
+ q = self._separate_heads(q, self.num_heads)
+ k = self._separate_heads(k, self.num_heads)
+ v = self._separate_heads(v, self.num_heads)
+
+ # Apply rotary position encoding
+ w = h = math.sqrt(q.shape[-2])
+ self.freqs_cis = self.freqs_cis.to(q.device)
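+ # Recompute the rotary frequency table if the cached grid size no longer matches the token count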
+ if self.freqs_cis.shape[0] != q.shape[-2]:
+ self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
+ if q.shape[-2] != k.shape[-2]:
+ assert self.rope_k_repeat
+
+ num_k_rope = k.size(-2) - num_k_exclude_rope
+ q, k[:, :, :num_k_rope] = apply_rotary_enc(
+ q,
+ k[:, :, :num_k_rope],
+ freqs_cis=self.freqs_cis,
+ repeat_freqs_k=self.rope_k_repeat,
+ )
+
+ # Attention
+ _, _, _, c_per_head = q.shape
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
+ attn = attn / math.sqrt(c_per_head)
+ attn = torch.softmax(attn, dim=-1)
+
+ # Get output
+ out = attn @ v
+
+ out = self._recombine_heads(out)
+ out = self.out_proj(out)
+
+ return out
+
+
+def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
+ """Applies pooling and optional normalization to a tensor, handling permutations for spatial operations."""
+ if pool is None:
+ return x
+ # (B, H, W, C) -> (B, C, H, W)
+ x = x.permute(0, 3, 1, 2)
+ x = pool(x)
+ # (B, C, H', W') -> (B, H', W', C)
+ x = x.permute(0, 2, 3, 1)
+ if norm:
+ x = norm(x)
+
+ return x
+
+
+class MultiScaleAttention(nn.Module):
+ """Implements multi-scale self-attention with optional query pooling for efficient feature extraction."""
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: int,
+ num_heads: int,
+ q_pool: nn.Module = None,
+ ):
+ """Initializes a multi-scale attention module with configurable query pooling and linear projections."""
+ super().__init__()
+
+ self.dim = dim
+ self.dim_out = dim_out
+
+ self.num_heads = num_heads
+ head_dim = dim_out // num_heads
+ self.scale = head_dim**-0.5
+
+ self.q_pool = q_pool
+ self.qkv = nn.Linear(dim, dim_out * 3)
+ self.proj = nn.Linear(dim_out, dim_out)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Applies multi-scale attention to input tensor, optionally downsampling query features."""
+ B, H, W, _ = x.shape
+ # qkv with shape (B, H * W, 3, nHead, C)
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
+ # q, k, v with shape (B, H * W, nheads, C)
+ q, k, v = torch.unbind(qkv, 2)
+
+ # Q pooling (for downsample at stage changes)
+ if self.q_pool:
+ q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
+ H, W = q.shape[1:3] # downsampled shape
+ q = q.reshape(B, H * W, self.num_heads, -1)
+
+ # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
+ x = F.scaled_dot_product_attention(
+ q.transpose(1, 2),
+ k.transpose(1, 2),
+ v.transpose(1, 2),
+ )
+ # Transpose back
+ x = x.transpose(1, 2)
+ x = x.reshape(B, H, W, -1)
+
+ x = self.proj(x)
+
+ return x
+
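An illustrative sketch with assumed values (not from the patch): when `q_pool` is set, queries are pooled so the output resolution is halved while channels widen from `dim` to `dim_out`.

```python
import torch
import torch.nn as nn

attn = MultiScaleAttention(dim=96, dim_out=192, num_heads=4, q_pool=nn.MaxPool2d(2, 2))
x = torch.randn(1, 8, 8, 96)  # (B, H, W, C)
print(attn(x).shape)  # torch.Size([1, 4, 4, 192])
```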
+
+class MultiScaleBlock(nn.Module):
+ """Multiscale attention block with window partitioning and query pooling for efficient vision transformers."""
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ drop_path: float = 0.0,
+ norm_layer: Union[nn.Module, str] = "LayerNorm",
+ q_stride: Tuple[int, int] = None,
+ act_layer: nn.Module = nn.GELU,
+ window_size: int = 0,
+ ):
+ """Initializes a multi-scale attention block with optional window partitioning and downsampling."""
+ super().__init__()
+
+ if isinstance(norm_layer, str):
+ norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
+
+ self.dim = dim
+ self.dim_out = dim_out
+ self.norm1 = norm_layer(dim)
+
+ self.window_size = window_size
+
+ self.pool, self.q_stride = None, q_stride
+ if self.q_stride:
+ self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False)
+
+ self.attn = MultiScaleAttention(
+ dim,
+ dim_out,
+ num_heads=num_heads,
+ q_pool=self.pool,
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim_out)
+ self.mlp = MLP(
+ dim_out,
+ int(dim_out * mlp_ratio),
+ dim_out,
+ num_layers=2,
+ act=act_layer,
+ )
+
+ if dim != dim_out:
+ self.proj = nn.Linear(dim, dim_out)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Applies multi-scale attention and MLP processing to input tensor, with optional windowing."""
+ shortcut = x # B, H, W, C
+ x = self.norm1(x)
+
+        # Project the skip connection (and pool it when q_stride is set) if channel dims change
+ if self.dim != self.dim_out:
+ shortcut = do_pool(self.proj(x), self.pool)
+
+ # Window partition
+ window_size = self.window_size
+ if window_size > 0:
+ H, W = x.shape[1], x.shape[2]
+ x, pad_hw = window_partition(x, window_size)
+
+ # Window Attention + Q Pooling (if stage change)
+ x = self.attn(x)
+ if self.q_stride:
+ # Shapes have changed due to Q pooling
+ window_size = self.window_size // self.q_stride[0]
+ H, W = shortcut.shape[1:3]
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ pad_hw = (H + pad_h, W + pad_w)
+
+ # Reverse window partition
+ if self.window_size > 0:
+ x = window_unpartition(x, window_size, pad_hw, (H, W))
+
+ x = shortcut + self.drop_path(x)
+ # MLP
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
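A hedged example with assumed hyperparameters: a stage-transition block that windows attention into 4x4 tiles, pools queries with stride 2, and widens channels, so the output is spatially halved.

```python
import torch

block = MultiScaleBlock(dim=96, dim_out=192, num_heads=4, q_stride=(2, 2), window_size=4)
x = torch.randn(1, 8, 8, 96)  # (B, H, W, C)
print(block(x).shape)  # torch.Size([1, 4, 4, 192])
```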
+
+class PositionEmbeddingSine(nn.Module):
+ """Generates sinusoidal positional embeddings for 2D inputs like images."""
+
+ def __init__(
+ self,
+ num_pos_feats,
+ temperature: int = 10000,
+ normalize: bool = True,
+ scale: Optional[float] = None,
+ ):
+ """Initializes sinusoidal position embeddings for 2D image inputs."""
+ super().__init__()
+ assert num_pos_feats % 2 == 0, "Expecting even model width"
+ self.num_pos_feats = num_pos_feats // 2
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ self.cache = {}
+
+ def _encode_xy(self, x, y):
+ """Encodes 2D positions using sine and cosine functions for positional embeddings."""
+ assert len(x) == len(y) and x.ndim == y.ndim == 1
+ x_embed = x * self.scale
+ y_embed = y * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, None] / dim_t
+ pos_y = y_embed[:, None] / dim_t
+ pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
+ pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
+ return pos_x, pos_y
+
+ @torch.no_grad()
+ def encode_boxes(self, x, y, w, h):
+ """Encodes box coordinates and dimensions into positional embeddings for object detection tasks."""
+ pos_x, pos_y = self._encode_xy(x, y)
+ pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
+ return pos
+
+ encode = encode_boxes # Backwards compatibility
+
+ @torch.no_grad()
+ def encode_points(self, x, y, labels):
+ """Encodes 2D point coordinates with sinusoidal positional embeddings and appends labels."""
+ (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
+ assert bx == by and nx == ny and bx == bl and nx == nl
+ pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
+ pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
+ pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
+ return pos
+
+ @torch.no_grad()
+ def forward(self, x: torch.Tensor):
+        """Generates sinusoidal position embeddings for 2D inputs."""
+ cache_key = (x.shape[-2], x.shape[-1])
+ if cache_key in self.cache:
+ return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
+ y_embed = (
+ torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
+ .view(1, -1, 1)
+ .repeat(x.shape[0], 1, x.shape[-1])
+ )
+ x_embed = (
+ torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
+ .view(1, 1, -1)
+ .repeat(x.shape[0], x.shape[-2], 1)
+ )
+
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ self.cache[cache_key] = pos[0]
+ return pos
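A minimal shape sketch (illustrative): the embedding width equals `num_pos_feats`, only the spatial dimensions of the input are used, and repeated calls with the same (H, W) are served from the cache.

```python
import torch

pe = PositionEmbeddingSine(num_pos_feats=256)
feats = torch.zeros(2, 64, 32, 32)  # (B, C, H, W); only H and W matter
print(pe(feats).shape)  # torch.Size([2, 256, 32, 32])
```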
diff --git a/ultralytics/models/sam2/modules/utils.py b/ultralytics/models/sam2/modules/utils.py
new file mode 100644
index 0000000000..b09dd9b2f4
--- /dev/null
+++ b/ultralytics/models/sam2/modules/utils.py
@@ -0,0 +1,191 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import torch
+import torch.nn.functional as F
+
+
+def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
+ """
+ Selects the closest conditioning frames to a given frame index.
+
+ Args:
+ frame_idx (int): Current frame index.
+ cond_frame_outputs (Dict[int, Any]): Dictionary of conditioning frame outputs keyed by frame indices.
+ max_cond_frame_num (int): Maximum number of conditioning frames to select.
+
+ Returns:
+ (Tuple[Dict[int, Any], Dict[int, Any]]): A tuple containing two dictionaries:
+ - selected_outputs: Selected items from cond_frame_outputs.
+ - unselected_outputs: Items not selected from cond_frame_outputs.
+
+ Examples:
+ >>> frame_idx = 5
+ >>> cond_frame_outputs = {1: 'a', 3: 'b', 7: 'c', 9: 'd'}
+ >>> max_cond_frame_num = 2
+ >>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)
+ >>> print(selected)
+ {3: 'b', 7: 'c'}
+ >>> print(unselected)
+ {1: 'a', 9: 'd'}
+ """
+ if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
+ selected_outputs = cond_frame_outputs
+ unselected_outputs = {}
+ else:
+        assert max_cond_frame_num >= 2, "max_cond_frame_num must be at least 2 when limiting conditioning frames"
+ selected_outputs = {}
+
+ # the closest conditioning frame before `frame_idx` (if any)
+ idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
+ if idx_before is not None:
+ selected_outputs[idx_before] = cond_frame_outputs[idx_before]
+
+ # the closest conditioning frame after `frame_idx` (if any)
+ idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
+ if idx_after is not None:
+ selected_outputs[idx_after] = cond_frame_outputs[idx_after]
+
+ # add other temporally closest conditioning frames until reaching a total
+ # of `max_cond_frame_num` conditioning frames.
+ num_remain = max_cond_frame_num - len(selected_outputs)
+ inds_remain = sorted(
+ (t for t in cond_frame_outputs if t not in selected_outputs),
+ key=lambda x: abs(x - frame_idx),
+ )[:num_remain]
+ selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
+ unselected_outputs = {t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs}
+
+ return selected_outputs, unselected_outputs
+
+
+def get_1d_sine_pe(pos_inds, dim, temperature=10000):
+ """Generates 1D sinusoidal positional embeddings for given positions and dimensions."""
+ pe_dim = dim // 2
+ dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
+ dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
+
+ pos_embed = pos_inds.unsqueeze(-1) / dim_t
+ pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
+ return pos_embed
+
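Illustrative usage: embedding three temporal positions into a 256-dimensional sinusoidal code.

```python
import torch

pe = get_1d_sine_pe(torch.tensor([0.0, 1.0, 2.0]), dim=256)
print(pe.shape)  # torch.Size([3, 256])
```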
+
+def init_t_xy(end_x: int, end_y: int):
+ """Initializes 1D and 2D coordinate tensors for a grid of size end_x by end_y."""
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
+ t_x = (t % end_x).float()
+ t_y = torch.div(t, end_x, rounding_mode="floor").float()
+ return t_x, t_y
+
+
+def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
+ """Computes axial complex exponential positional encodings for 2D spatial positions."""
+ freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+ freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+
+ t_x, t_y = init_t_xy(end_x, end_y)
+ freqs_x = torch.outer(t_x, freqs_x)
+ freqs_y = torch.outer(t_y, freqs_y)
+ freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
+ freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
+ return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
+
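Illustrative shape check: for a 64-dimensional head on an 8x8 grid, the axial frequencies form one complex row per grid position with `dim // 2` columns.

```python
freqs_cis = compute_axial_cis(dim=64, end_x=8, end_y=8)
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([64, 32]) torch.complex64
```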
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+ """Reshapes frequency tensor for broadcasting, ensuring compatibility with input tensor dimensions."""
+ ndim = x.ndim
+    assert 0 <= 1 < ndim  # Input must have at least 2 dims so freqs_cis can broadcast over the last two
+ assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
+ shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis.view(*shape)
+
+
+def apply_rotary_enc(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ repeat_freqs_k: bool = False,
+):
+ """Applies rotary positional encoding to query and key tensors using complex-valued frequency components."""
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+ if xk_ is None:
+        # No keys to rotate (due to dropout); return keys unchanged
+ return xq_out.type_as(xq).to(xq.device), xk
+ # repeat freqs along seq_len dim to match k seq_len
+ if repeat_freqs_k:
+ r = xk_.shape[-2] // xq_.shape[-2]
+ freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
+
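A hedged usage sketch: rotating 64 query/key tokens (an 8x8 grid) per head with the axial frequencies above; the input shapes are preserved.

```python
import torch

freqs_cis = compute_axial_cis(dim=64, end_x=8, end_y=8)  # (64, 32), complex
q = torch.randn(1, 2, 64, 64)  # (B, n_heads, seq_len, head_dim)
k = torch.randn(1, 2, 64, 64)
q_rot, k_rot = apply_rotary_enc(q, k, freqs_cis=freqs_cis)
print(q_rot.shape, k_rot.shape)  # both torch.Size([1, 2, 64, 64])
```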
+
+def window_partition(x, window_size):
+ """
+ Partitions input tensor into non-overlapping windows with padding if needed.
+
+ Args:
+ x (torch.Tensor): Input tensor with shape (B, H, W, C).
+ window_size (int): Size of each window.
+
+ Returns:
+ (Tuple[torch.Tensor, Tuple[int, int]]): A tuple containing:
+ - windows (torch.Tensor): Partitioned windows with shape (B * num_windows, window_size, window_size, C).
+ - (Hp, Wp) (Tuple[int, int]): Padded height and width before partition.
+
+ Examples:
+ >>> x = torch.randn(1, 16, 16, 3)
+ >>> windows, (Hp, Wp) = window_partition(x, window_size=4)
+ >>> print(windows.shape, Hp, Wp)
+ torch.Size([16, 4, 4, 3]) 16 16
+ """
+ B, H, W, C = x.shape
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+ Hp, Wp = H + pad_h, W + pad_w
+
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows, (Hp, Wp)
+
+
+def window_unpartition(windows, window_size, pad_hw, hw):
+ """
+ Unpartitions windowed sequences into original sequences and removes padding.
+
+ This function reverses the windowing process, reconstructing the original input from windowed segments
+ and removing any padding that was added during the windowing process.
+
+ Args:
+ windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size,
+ window_size, C), where B is the batch size, num_windows is the number of windows, window_size is
+ the size of each window, and C is the number of channels.
+ window_size (int): Size of each window.
+ pad_hw (Tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing.
+ hw (Tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.
+
+ Returns:
+ (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W
+ are the original height and width, and C is the number of channels.
+
+ Examples:
+ >>> windows = torch.rand(32, 8, 8, 64) # 32 windows of size 8x8 with 64 channels
+ >>> pad_hw = (16, 16) # Padded height and width
+ >>> hw = (15, 14) # Original height and width
+ >>> x = window_unpartition(windows, window_size=8, pad_hw=pad_hw, hw=hw)
+ >>> print(x.shape)
+ torch.Size([1, 15, 14, 64])
+ """
+ Hp, Wp = pad_hw
+ H, W = hw
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+ if Hp > H or Wp > W:
+ x = x[:, :H, :W, :].contiguous()
+ return x
diff --git a/ultralytics/models/sam2/predict.py b/ultralytics/models/sam2/predict.py
new file mode 100644
index 0000000000..ca9438a2b7
--- /dev/null
+++ b/ultralytics/models/sam2/predict.py
@@ -0,0 +1,182 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import torch
+
+from ..sam.predict import Predictor
+from .build import build_sam2
+
+
+class SAM2Predictor(Predictor):
+ """
+ A predictor class for the Segment Anything Model 2 (SAM2), extending the base Predictor class.
+
+ This class provides an interface for model inference tailored to image segmentation tasks, leveraging SAM2's
+ advanced architecture and promptable segmentation capabilities. It facilitates flexible and real-time mask
+ generation, working with various types of prompts such as bounding boxes, points, and low-resolution masks.
+
+ Attributes:
+ cfg (Dict): Configuration dictionary specifying model and task-related parameters.
+ overrides (Dict): Dictionary containing values that override the default configuration.
+ _callbacks (Dict): Dictionary of user-defined callback functions to augment behavior.
+ args (namespace): Namespace to hold command-line arguments or other operational variables.
+ im (torch.Tensor): Preprocessed input image tensor.
+        features (Dict[str, torch.Tensor]): Cached image features ("image_embed" and "high_res_feats") used for inference.
+ prompts (Dict): Collection of various prompt types, such as bounding boxes and points.
+ segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones.
+ model (torch.nn.Module): The loaded SAM2 model.
+ device (torch.device): The device (CPU or GPU) on which the model is loaded.
+ _bb_feat_sizes (List[Tuple[int, int]]): List of feature sizes for different backbone levels.
+
+ Methods:
+ get_model: Builds and returns the SAM2 model.
+ prompt_inference: Performs image segmentation inference based on various prompts.
+ set_image: Preprocesses and sets a single image for inference.
+ get_im_features: Extracts image features from the SAM2 image encoder.
+
+ Examples:
+ >>> predictor = SAM2Predictor(model='sam2_l.pt')
+ >>> predictor.set_image('path/to/image.jpg')
+ >>> masks, scores = predictor.prompt_inference(im=predictor.im, points=[[500, 375]], labels=[1])
+ >>> print(f"Generated {len(masks)} mask(s) with scores: {scores}")
+ """
+
+ _bb_feat_sizes = [
+ (256, 256),
+ (128, 128),
+ (64, 64),
+ ]
+
+ def get_model(self):
+        """Retrieves and initializes the Segment Anything Model 2 (SAM2) for image segmentation tasks."""
+ return build_sam2(self.args.model)
+
+ def prompt_inference(
+ self,
+ im,
+ bboxes=None,
+ points=None,
+ labels=None,
+ masks=None,
+ multimask_output=False,
+ img_idx=-1,
+ ):
+ """
+ Performs image segmentation inference based on various prompts using SAM2 architecture.
+
+ Args:
+ im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
+ bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
+ points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
+ labels (np.ndarray | List | None): Labels for point prompts with shape (N,). 1 = foreground, 0 = background.
+ masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W).
+ multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
+ img_idx (int): Index of the image in the batch to process.
+
+        Returns:
+            (tuple): Tuple containing:
+                - torch.Tensor: Output masks with shape (C, H, W), where C is the number of generated masks.
+                - torch.Tensor: Quality scores for each mask, with length C.
+
+ Examples:
+ >>> predictor = SAM2Predictor(cfg)
+ >>> image = torch.rand(1, 3, 640, 640)
+ >>> bboxes = [[100, 100, 200, 200]]
+            >>> masks, scores = predictor.prompt_inference(image, bboxes=bboxes)
+ """
+ features = self.get_im_features(im) if self.features is None else self.features
+
+ src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
+ r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
+ # Transform input prompts
+ if points is not None:
+ points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
+ points = points[None] if points.ndim == 1 else points
+ # Assuming labels are all positive if users don't pass labels.
+ if labels is None:
+ labels = torch.ones(points.shape[0])
+ labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
+ points *= r
+ # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
+ points, labels = points[:, None], labels[:, None]
+ if bboxes is not None:
+ bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
+ bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
+ bboxes *= r
+ if masks is not None:
+ masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
+
+ points = (points, labels) if points is not None else None
+ # TODO: Embed prompts
+ # if bboxes is not None:
+ # box_coords = bboxes.reshape(-1, 2, 2)
+ # box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=bboxes.device)
+ # box_labels = box_labels.repeat(bboxes.size(0), 1)
+ # # we merge "boxes" and "points" into a single "concat_points" input (where
+ # # boxes are added at the beginning) to sam_prompt_encoder
+ # if concat_points is not None:
+ # concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
+ # concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
+ # concat_points = (concat_coords, concat_labels)
+ # else:
+ # concat_points = (box_coords, box_labels)
+
+ sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
+ points=points,
+ boxes=bboxes,
+ masks=masks,
+ )
+ # Predict masks
+ batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction
+ high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
+ pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
+ image_embeddings=features["image_embed"][img_idx].unsqueeze(0),
+ image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ repeat_image=batched_mode,
+ high_res_features=high_res_features,
+ )
+ # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
+        # `d` could be 1 or 3 depending on `multimask_output`.
+ return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
+
+ def set_image(self, image):
+ """
+ Preprocesses and sets a single image for inference.
+
+ This function sets up the model if not already initialized, configures the data source to the specified image,
+ and preprocesses the image for feature extraction. Only one image can be set at a time.
+
+ Args:
+ image (str | np.ndarray): Image file path as a string, or a numpy array image read by cv2.
+
+ Raises:
+ AssertionError: If more than one image is set.
+
+ Examples:
+ >>> predictor = SAM2Predictor()
+ >>> predictor.set_image("path/to/image.jpg")
+ >>> predictor.set_image(np.array([...])) # Using a numpy array
+ """
+ if self.model is None:
+ self.setup_model(model=None)
+ self.setup_source(image)
+ assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
+ for batch in self.dataset:
+ im = self.preprocess(batch[1])
+ self.features = self.get_im_features(im)
+ break
+
+ def get_im_features(self, im):
+ """Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks."""
+ backbone_out = self.model.forward_image(im)
+ _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
+ if self.model.directly_add_no_mem_embed:
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+ feats = [
+ feat.permute(1, 2, 0).view(1, -1, *feat_size)
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
+ ][::-1]
+ return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
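A hedged sketch (override-style construction and a local `sam2_l.pt` checkpoint are assumed) of the cached feature layout after `set_image`; spatial sizes follow `_bb_feat_sizes`, while channel counts depend on the backbone and are not asserted here.

```python
predictor = SAM2Predictor(overrides=dict(model="sam2_l.pt"))  # assumed constructor usage
predictor.set_image("path/to/image.jpg")
print(predictor.features["image_embed"].shape[-2:])  # torch.Size([64, 64])
print([f.shape[-2:] for f in predictor.features["high_res_feats"]])  # 256x256 and 128x128 maps
```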
diff --git a/ultralytics/nn/modules/transformer.py b/ultralytics/nn/modules/transformer.py
index 062c6094ea..4c170aa317 100644
--- a/ultralytics/nn/modules/transformer.py
+++ b/ultralytics/nn/modules/transformer.py
@@ -174,18 +174,20 @@ class MLPBlock(nn.Module):
class MLP(nn.Module):
"""Implements a simple multi-layer perceptron (also called FFN)."""
- def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act=nn.ReLU, sigmoid=False):
"""Initialize the MLP with specified input, hidden, output dimensions and number of layers."""
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+ self.sigmoid = sigmoid
+ self.act = act()
def forward(self, x):
"""Forward pass for the entire MLP."""
for i, layer in enumerate(self.layers):
- x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
- return x
+ x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x.sigmoid() if self.sigmoid else x
class LayerNorm2d(nn.Module):
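A hedged example of the extended `MLP` signature: a two-layer GELU MLP whose outputs are squashed into (0, 1) via the new `sigmoid` flag.

```python
import torch
import torch.nn as nn

from ultralytics.nn.modules.transformer import MLP

mlp = MLP(input_dim=256, hidden_dim=1024, output_dim=4, num_layers=2, act=nn.GELU, sigmoid=True)
y = mlp(torch.randn(8, 256))
print(y.shape, y.min().item() >= 0, y.max().item() <= 1)  # torch.Size([8, 4]) True True
```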
diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
index d220430386..624167694f 100644
--- a/ultralytics/utils/torch_utils.py
+++ b/ultralytics/utils/torch_utils.py
@@ -1,5 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-
+import contextlib
import gc
import math
import os
@@ -101,12 +101,15 @@ def autocast(enabled: bool, device: str = "cuda"):
def get_cpu_info():
"""Return a string with system CPU information, i.e. 'Apple M2'."""
- import cpuinfo # pip install py-cpuinfo
+ with contextlib.suppress(Exception):
+ import cpuinfo # pip install py-cpuinfo
+
+ k = "brand_raw", "hardware_raw", "arch_string_raw" # keys sorted by preference (not all keys always available)
+ info = cpuinfo.get_cpu_info() # info dict
+ string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
+ return string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
- k = "brand_raw", "hardware_raw", "arch_string_raw" # info keys sorted by preference (not all keys always available)
- info = cpuinfo.get_cpu_info() # info dict
- string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
- return string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
+ return "unknown"
def select_device(device="", batch=0, newline=False, verbose=True):
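A brief behavior sketch for the new fallback: when `py-cpuinfo` is importable the call returns a cleaned brand string, otherwise it now degrades to `"unknown"` instead of raising.

```python
from ultralytics.utils.torch_utils import get_cpu_info

print(get_cpu_info())  # e.g. 'Apple M2', or 'unknown' if py-cpuinfo is unavailable
```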