Merge branch 'main' into model_yaml_args

model_yaml_args
Glenn Jocher 1 year ago committed by GitHub
commit 8628f7345a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      docs/models/index.md
  2. 99
      docs/models/mobile-sam.md
  3. 94
      docs/models/sam.md
  4. 9
      docs/reference/vit/sam/autosize.md
  5. 5
      docs/reference/vit/sam/build.md
  6. 9
      docs/reference/vit/sam/modules/mask_generator.md
  7. 9
      docs/reference/vit/sam/modules/prompt_predictor.md
  8. 59
      docs/reference/vit/sam/modules/tiny_encoder.md
  9. 9
      docs/reference/yolo/fastsam/model.md
  10. 9
      docs/reference/yolo/fastsam/predict.md
  11. 9
      docs/reference/yolo/fastsam/prompt.md
  12. 14
      docs/reference/yolo/fastsam/utils.md
  13. 9
      docs/reference/yolo/fastsam/val.md
  14. 5
      docs/reference/yolo/utils/ops.md
  15. 11
      mkdocs.yml
  16. 2
      ultralytics/__init__.py
  17. 9
      ultralytics/vit/sam/__init__.py
  18. 94
      ultralytics/vit/sam/autosize.py
  19. 48
      ultralytics/vit/sam/build.py
  20. 12
      ultralytics/vit/sam/model.py
  21. 353
      ultralytics/vit/sam/modules/mask_generator.py
  22. 242
      ultralytics/vit/sam/modules/prompt_predictor.py
  23. 653
      ultralytics/vit/sam/modules/tiny_encoder.py
  24. 382
      ultralytics/vit/sam/predict.py
  25. 24
      ultralytics/yolo/data/annotator.py
  26. 8
      ultralytics/yolo/data/augment.py
  27. 17
      ultralytics/yolo/engine/predictor.py
  28. 10
      ultralytics/yolo/engine/results.py
  29. 14
      ultralytics/yolo/fastsam/utils.py
  30. 3
      ultralytics/yolo/utils/downloads.py
  31. 1
      ultralytics/yolo/utils/instance.py
  32. 44
      ultralytics/yolo/utils/ops.py

@ -17,6 +17,7 @@ In this documentation, we provide information on four major models:
5. [YOLOv7](./yolov7.md): Updated YOLO models released in 2022 by the authors of YOLOv4. 5. [YOLOv7](./yolov7.md): Updated YOLO models released in 2022 by the authors of YOLOv4.
6. [YOLOv8](./yolov8.md): The latest version of the YOLO family, featuring enhanced capabilities such as instance segmentation, pose/keypoints estimation, and classification. 6. [YOLOv8](./yolov8.md): The latest version of the YOLO family, featuring enhanced capabilities such as instance segmentation, pose/keypoints estimation, and classification.
7. [Segment Anything Model (SAM)](./sam.md): Meta's Segment Anything Model (SAM). 7. [Segment Anything Model (SAM)](./sam.md): Meta's Segment Anything Model (SAM).
7. [Mobile Segment Anything Model (MobileSAM)](./mobile-sam.md): MobileSAM for mobile applications by Kyung Hee University.
8. [Fast Segment Anything Model (FastSAM)](./fast-sam.md): FastSAM by Image & Video Analysis Group, Institute of Automation, Chinese Academy of Sciences. 8. [Fast Segment Anything Model (FastSAM)](./fast-sam.md): FastSAM by Image & Video Analysis Group, Institute of Automation, Chinese Academy of Sciences.
9. [YOLO-NAS](./yolo-nas.md): YOLO Neural Architecture Search (NAS) Models. 9. [YOLO-NAS](./yolo-nas.md): YOLO Neural Architecture Search (NAS) Models.
10. [Realtime Detection Transformers (RT-DETR)](./rtdetr.md): Baidu's PaddlePaddle Realtime Detection Transformer (RT-DETR) models. 10. [Realtime Detection Transformers (RT-DETR)](./rtdetr.md): Baidu's PaddlePaddle Realtime Detection Transformer (RT-DETR) models.

@ -0,0 +1,99 @@
---
comments: true
description: MobileSAM is a lightweight adaptation of the Segment Anything Model (SAM) designed for mobile applications. It maintains the full functionality of the original SAM while significantly improving speed, making it suitable for CPU-only edge devices, such as mobile phones.
keywords: MobileSAM, Faster Segment Anything, Segment Anything, Segment Anything Model, SAM, Meta SAM, image segmentation, promptable segmentation, zero-shot performance, SA-1B dataset, advanced architecture, auto-annotation, Ultralytics, pre-trained models, SAM base, SAM large, instance segmentation, computer vision, AI, artificial intelligence, machine learning, data annotation, segmentation masks, detection model, YOLO detection model, bibtex, Meta AI
---
![MobileSAM Logo](https://github.com/ChaoningZhang/MobileSAM/blob/master/assets/logo2.png?raw=true)
# Faster Segment Anything (MobileSAM)
The MobileSAM paper is now available on [ResearchGate](https://www.researchgate.net/publication/371851844_Faster_Segment_Anything_Towards_Lightweight_SAM_for_Mobile_Applications) and [arXiv](https://arxiv.org/pdf/2306.14289.pdf). The most recent version will initially appear on ResearchGate due to the delayed content update on arXiv.
A demonstration of MobileSAM running on a CPU can be accessed at this [demo link](https://huggingface.co/spaces/dhkim2810/MobileSAM). The performance on a Mac i5 CPU takes approximately 3 seconds. On the Hugging Face demo, the interface and lower-performance CPUs contribute to a slower response, but it continues to function effectively.
MobileSAM is implemented in various projects including [Grounding-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything), [AnyLabeling](https://github.com/vietanhdev/anylabeling), and [SegmentAnythingin3D](https://github.com/Jumpat/SegmentAnythingin3D).
MobileSAM is trained on a single GPU with a 100k dataset (1% of the original images) in less than a day. The code for this training will be made available in the future.
## Adapting from SAM to MobileSAM
Since MobileSAM retains the same pipeline as the original SAM, we have incorporated the original's pre-processing, post-processing, and all other interfaces. Consequently, those currently using the original SAM can transition to MobileSAM with minimal effort.
MobileSAM performs comparably to the original SAM and retains the same pipeline except for a change in the image encoder. Specifically, we replace the original heavyweight ViT-H encoder (632M) with a smaller Tiny-ViT (5M). On a single GPU, MobileSAM operates at about 12ms per image: 8ms on the image encoder and 4ms on the mask decoder.
The following table provides a comparison of ViT-based image encoders:
| Image Encoder | Original SAM | MobileSAM |
|---------------|--------------|-----------|
| Parameters | 611M | 5M |
| Speed | 452ms | 8ms |
Both the original SAM and MobileSAM utilize the same prompt-guided mask decoder:
| Mask Decoder | Original SAM | MobileSAM |
|--------------|--------------|-----------|
| Parameters | 3.876M | 3.876M |
| Speed | 4ms | 4ms |
Here is the comparison of the whole pipeline:
| Whole Pipeline (Enc+Dec) | Original SAM | MobileSAM |
|--------------------------|--------------|-----------|
| Parameters | 615M | 9.66M |
| Speed | 456ms | 12ms |
The performance of MobileSAM and the original SAM are demonstrated using both a point and a box as prompts.
![Image with Point as Prompt](https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/master/assets/mask_box.jpg?raw=true)
![Image with Box as Prompt](https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/master/assets/mask_box.jpg?raw=true)
With its superior performance, MobileSAM is approximately 5 times smaller and 7 times faster than the current FastSAM. More details are available at the [MobileSAM project page](https://github.com/ChaoningZhang/MobileSAM).
## Testing MobileSAM in Ultralytics
Just like the original SAM, we offer a straightforward testing method in Ultralytics, including modes for both Point and Box prompts.
### Model Download
You can download the model [here](https://github.com/ChaoningZhang/MobileSAM/blob/master/weights/mobile_sam.pt).
### Point Prompt
```python
from ultralytics import SAM
# Load the model
model = SAM('mobile_sam.pt')
# Predict a segment based on a point prompt
model.predict('ultralytics/assets/zidane.jpg', points=[900, 370], labels=[1])
```
### Box Prompt
```python
from ultralytics import SAM
# Load the model
model = SAM('mobile_sam.pt')
# Predict a segment based on a box prompt
model.predict('ultralytics/assets/zidane.jpg', bboxes=[439, 437, 524, 709])
```
We have implemented `MobileSAM` and `SAM` using the same API. For more usage information, please see the [SAM page](./sam.md).
### Citing MobileSAM
If you find MobileSAM useful in your research or development work, please consider citing our paper:
```bibtex
@article{mobile_sam,
title={Faster Segment Anything: Towards Lightweight SAM for Mobile Applications},
author={Zhang, Chaoning and Han, Dongshen and Qiao, Yu and Kim, Jung Uk and Bae, Sung Ho and Lee, Seungkyu and Hong, Choong Seon},
journal={arXiv preprint arXiv:2306.14289},
year={2023}
}
```

@ -30,9 +30,33 @@ For an in-depth look at the Segment Anything Model and the SA-1B dataset, please
The Segment Anything Model can be employed for a multitude of downstream tasks that go beyond its training data. This includes edge detection, object proposal generation, instance segmentation, and preliminary text-to-mask prediction. With prompt engineering, SAM can swiftly adapt to new tasks and data distributions in a zero-shot manner, establishing it as a versatile and potent tool for all your image segmentation needs. The Segment Anything Model can be employed for a multitude of downstream tasks that go beyond its training data. This includes edge detection, object proposal generation, instance segmentation, and preliminary text-to-mask prediction. With prompt engineering, SAM can swiftly adapt to new tasks and data distributions in a zero-shot manner, establishing it as a versatile and potent tool for all your image segmentation needs.
!!! example "SAM prediction example" ### SAM prediction example
Device is determined automatically. If a GPU is available then it will be used, otherwise inference will run on CPU. !!! example "Segment with prompts"
Segment image with given prompts.
=== "Python"
```python
from ultralytics import SAM
# Load a model
model = SAM('sam_b.pt')
# Display model information (optional)
model.info()
# Run inference with bboxes prompt
model('ultralytics/assets/zidane.jpg', bboxes=[439, 437, 524, 709])
# Run inference with points prompt
model.predict('ultralytics/assets/zidane.jpg', points=[900, 370], labels=[1])
```
!!! example "Segment everything"
Segment the whole image.
=== "Python" === "Python"
@ -45,7 +69,7 @@ The Segment Anything Model can be employed for a multitude of downstream tasks t
# Display model information (optional) # Display model information (optional)
model.info() model.info()
# Run inference with the model # Run inference
model('path/to/image.jpg') model('path/to/image.jpg')
``` ```
=== "CLI" === "CLI"
@ -55,6 +79,48 @@ The Segment Anything Model can be employed for a multitude of downstream tasks t
yolo predict model=sam_b.pt source=path/to/image.jpg yolo predict model=sam_b.pt source=path/to/image.jpg
``` ```
- The logic here is to segment the whole image if you don't pass any prompts(bboxes/points/masks).
!!! example "SAMPredictor example"
This way you can set image once and run prompts inference multiple times without running image encoder multiple times.
=== "Prompt inference"
```python
from ultralytics.vit.sam import Predictor as SAMPredictor
# Create SAMPredictor
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024, model="mobile_sam.pt")
predictor = SAMPredictor(overrides=overrides)
# Set image
predictor.set_image("ultralytics/assets/zidane.jpg") # set with image file
predictor.set_image(cv2.imread("ultralytics/assets/zidane.jpg")) # set with np.ndarray
results = predictor(bboxes=[439, 437, 524, 709])
results = predictor(points=[900, 370], labels=[1])
# Reset image
predictor.reset_image()
```
Segment everything with additional args.
=== "Segment everything"
```python
from ultralytics.vit.sam import Predictor as SAMPredictor
# Create SAMPredictor
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024, model="mobile_sam.pt")
predictor = SAMPredictor(overrides=overrides)
# segment with additional args
results = predictor(source="ultralytics/assets/zidane.jpg", crop_n_layers=1, points_stride=64)
```
- More additional args for `Segment everything` see [`Predictor/generate` Reference](../reference/vit/sam/predict.md).
## Available Models and Supported Tasks ## Available Models and Supported Tasks
| Model Type | Pre-trained Weights | Tasks Supported | | Model Type | Pre-trained Weights | Tasks Supported |
@ -76,21 +142,33 @@ Here we compare Meta's smallest SAM model, SAM-b, with Ultralytics smallest segm
| Model | Size | Parameters | Speed (CPU) | | Model | Size | Parameters | Speed (CPU) |
|------------------------------------------------|----------------------------|------------------------|-------------------------| |------------------------------------------------|----------------------------|------------------------|-------------------------|
| Meta's SAM-b | 358 MB | 94.7 M | 51096 ms | | Meta's SAM-b | 358 MB | 94.7 M | 51096 ms/im |
| Ultralytics [YOLOv8n-seg](../tasks/segment.md) | **6.7 MB** (53.4x smaller) | **3.4 M** (27.9x less) | **59 ms** (866x faster) | | [MobileSAM](mobile-sam.md) | 40.7 MB | 10.1 M | 46122 ms/im |
| [FastSAM-s](fast-sam.md) with YOLOv8 backbone | 23.7 MB | 11.8 M | 115 ms/im |
| Ultralytics [YOLOv8n-seg](../tasks/segment.md) | **6.7 MB** (53.4x smaller) | **3.4 M** (27.9x less) | **59 ms/im** (866x faster) |
This comparison shows the order-of-magnitude differences in the model sizes and speeds. Whereas SAM presents unique capabilities for automatic segmenting, it is not a direct competitor to YOLOv8 segment models, which are smaller, faster and more efficient since they are dedicated to more targeted use cases. This comparison shows the order-of-magnitude differences in the model sizes and speeds between models. Whereas SAM presents unique capabilities for automatic segmenting, it is not a direct competitor to YOLOv8 segment models, which are smaller, faster and more efficient.
To reproduce this test: Tests run on a 2023 Apple M2 Macbook with 16GB of RAM. To reproduce this test:
```python ```python
from ultralytics import SAM, YOLO from ultralytics import FastSAM, SAM, YOLO
# Profile SAM-b # Profile SAM-b
model = SAM('sam_b.pt') model = SAM('sam_b.pt')
model.info() model.info()
model('ultralytics/assets') model('ultralytics/assets')
# Profile MobileSAM
model = SAM('mobile_sam.pt')
model.info()
model('ultralytics/assets')
# Profile FastSAM-s
model = FastSAM('FastSAM-s.pt')
model.info()
model('ultralytics/assets')
# Profile YOLOv8n-seg # Profile YOLOv8n-seg
model = YOLO('yolov8n-seg.pt') model = YOLO('yolov8n-seg.pt')
model.info() model.info()

@ -1,9 +0,0 @@
---
description: Learn how to use the ResizeLongestSide module in Ultralytics YOLO for automatic image resizing. Resize your images with ease.
keywords: ResizeLongestSide, Ultralytics YOLO, automatic image resizing, image resizing
---
## ResizeLongestSide
---
### ::: ultralytics.vit.sam.autosize.ResizeLongestSide
<br><br>

@ -18,6 +18,11 @@ keywords: SAM, VIT, computer vision models, build SAM models, build VIT models,
### ::: ultralytics.vit.sam.build.build_sam_vit_b ### ::: ultralytics.vit.sam.build.build_sam_vit_b
<br><br> <br><br>
## build_mobile_sam
---
### ::: ultralytics.vit.sam.build.build_mobile_sam
<br><br>
## _build_sam ## _build_sam
--- ---
### ::: ultralytics.vit.sam.build._build_sam ### ::: ultralytics.vit.sam.build._build_sam

@ -1,9 +0,0 @@
---
description: Learn about the SamAutomaticMaskGenerator module in Ultralytics YOLO, an automatic mask generator for image segmentation.
keywords: SamAutomaticMaskGenerator, Ultralytics YOLO, automatic mask generator, image segmentation
---
## SamAutomaticMaskGenerator
---
### ::: ultralytics.vit.sam.modules.mask_generator.SamAutomaticMaskGenerator
<br><br>

@ -1,9 +0,0 @@
---
description: Learn about PromptPredictor - a module in Ultralytics VIT SAM that predicts image captions based on prompts. Get started today!.
keywords: PromptPredictor, Ultralytics, YOLO, VIT SAM, image captioning, deep learning, computer vision
---
## PromptPredictor
---
### ::: ultralytics.vit.sam.modules.prompt_predictor.PromptPredictor
<br><br>

@ -0,0 +1,59 @@
---
description: Learn about the Conv2d_BN, MBConv, ConvLayer, Attention, BasicLayer, and TinyViT modules.
keywords: Conv2d_BN, MBConv, ConvLayer, Attention, BasicLayer, TinyViT
---
## Conv2d_BN
---
### ::: ultralytics.vit.sam.modules.tiny_encoder.Conv2d_BN
<br><br>
## PatchEmbed
---
### ::: ultralytics.vit.sam.modules.tiny_encoder.PatchEmbed
<br><br>
## MBConv
---
### ::: ultralytics.vit.sam.modules.tiny_encoder.MBConv
<br><br>
## PatchMerging
---
### ::: ultralytics.vit.sam.modules.tiny_encoder.PatchMerging
<br><br>
## ConvLayer
---
### ::: ultralytics.vit.sam.modules.tiny_encoder.ConvLayer
<br><br>
## Mlp
---
### ::: ultralytics.vit.sam.modules.tiny_encoder.Mlp
<br><br>
## Attention
---
### ::: ultralytics.vit.sam.modules.tiny_encoder.Attention
<br><br>
## TinyViTBlock
---
### ::: ultralytics.vit.sam.modules.tiny_encoder.TinyViTBlock
<br><br>
## BasicLayer
---
### ::: ultralytics.vit.sam.modules.tiny_encoder.BasicLayer
<br><br>
## LayerNorm2d
---
### ::: ultralytics.vit.sam.modules.tiny_encoder.LayerNorm2d
<br><br>
## TinyViT
---
### ::: ultralytics.vit.sam.modules.tiny_encoder.TinyViT
<br><br>

@ -0,0 +1,9 @@
---
description: Learn how to use FastSAM in Ultralytics YOLO to improve object detection accuracy and speed.
keywords: FastSAM, object detection, accuracy, speed, Ultralytics YOLO
---
## FastSAM
---
### ::: ultralytics.yolo.fastsam.model.FastSAM
<br><br>

@ -0,0 +1,9 @@
---
description: FastSAMPredictor API reference and usage guide for the Ultralytics YOLO object detection library.
keywords: FastSAMPredictor, API, reference, usage, guide, Ultralytics, YOLO, object detection, library
---
## FastSAMPredictor
---
### ::: ultralytics.yolo.fastsam.predict.FastSAMPredictor
<br><br>

@ -0,0 +1,9 @@
---
description: Learn how to use FastSAMPrompt in Ultralytics YOLO for fast and efficient object detection and tracking.
keywords: FastSAMPrompt, Ultralytics YOLO, object detection, tracking, fast, efficient
---
## FastSAMPrompt
---
### ::: ultralytics.yolo.fastsam.prompt.FastSAMPrompt
<br><br>

@ -0,0 +1,14 @@
---
description: Learn how to adjust bounding boxes to the image border in Ultralytics YOLO framework. Improve object detection accuracy by accounting for image borders.
keywords: adjust_bboxes_to_image_border, Ultralytics YOLO, object detection, bounding boxes, image border
---
## adjust_bboxes_to_image_border
---
### ::: ultralytics.yolo.fastsam.utils.adjust_bboxes_to_image_border
<br><br>
## bbox_iou
---
### ::: ultralytics.yolo.fastsam.utils.bbox_iou
<br><br>

@ -0,0 +1,9 @@
---
description: Learn about the FastSAMValidator module in Ultralytics YOLO. Validate and evaluate Segment Anything Model (SAM) datasets for object detection models with ease.
keywords: FastSAMValidator, Ultralytics YOLO, SAM datasets, object detection, validation, evaluation
---
## FastSAMValidator
---
### ::: ultralytics.yolo.fastsam.val.FastSAMValidator
<br><br>

@ -123,6 +123,11 @@ keywords: Ultralytics, YOLO, Utils Ops, Functions, coco80_to_coco91_class, scale
### ::: ultralytics.yolo.utils.ops.process_mask_native ### ::: ultralytics.yolo.utils.ops.process_mask_native
<br><br> <br><br>
## scale_masks
---
### ::: ultralytics.yolo.utils.ops.scale_masks
<br><br>
## scale_coords ## scale_coords
--- ---
### ::: ultralytics.yolo.utils.ops.scale_coords ### ::: ultralytics.yolo.utils.ops.scale_coords

@ -168,6 +168,7 @@ nav:
- YOLOv7: models/yolov7.md - YOLOv7: models/yolov7.md
- YOLOv8: models/yolov8.md - YOLOv8: models/yolov8.md
- SAM (Segment Anything Model): models/sam.md - SAM (Segment Anything Model): models/sam.md
- MobileSAM (Mobile Segment Anything Model): models/mobile-sam.md
- FastSAM (Fast Segment Anything Model): models/fast-sam.md - FastSAM (Fast Segment Anything Model): models/fast-sam.md
- YOLO-NAS (Neural Architecture Search): models/yolo-nas.md - YOLO-NAS (Neural Architecture Search): models/yolo-nas.md
- RT-DETR (Realtime Detection Transformer): models/rtdetr.md - RT-DETR (Realtime Detection Transformer): models/rtdetr.md
@ -282,15 +283,13 @@ nav:
- val: reference/vit/rtdetr/val.md - val: reference/vit/rtdetr/val.md
- sam: - sam:
- amg: reference/vit/sam/amg.md - amg: reference/vit/sam/amg.md
- autosize: reference/vit/sam/autosize.md
- build: reference/vit/sam/build.md - build: reference/vit/sam/build.md
- model: reference/vit/sam/model.md - model: reference/vit/sam/model.md
- modules: - modules:
- decoders: reference/vit/sam/modules/decoders.md - decoders: reference/vit/sam/modules/decoders.md
- encoders: reference/vit/sam/modules/encoders.md - encoders: reference/vit/sam/modules/encoders.md
- mask_generator: reference/vit/sam/modules/mask_generator.md
- prompt_predictor: reference/vit/sam/modules/prompt_predictor.md
- sam: reference/vit/sam/modules/sam.md - sam: reference/vit/sam/modules/sam.md
- tiny_encoder: reference/vit/sam/modules/tiny_encoder.md
- transformer: reference/vit/sam/modules/transformer.md - transformer: reference/vit/sam/modules/transformer.md
- predict: reference/vit/sam/predict.md - predict: reference/vit/sam/predict.md
- utils: - utils:
@ -319,6 +318,12 @@ nav:
- results: reference/yolo/engine/results.md - results: reference/yolo/engine/results.md
- trainer: reference/yolo/engine/trainer.md - trainer: reference/yolo/engine/trainer.md
- validator: reference/yolo/engine/validator.md - validator: reference/yolo/engine/validator.md
- fastsam:
- model: reference/yolo/fastsam/model.md
- predict: reference/yolo/fastsam/predict.md
- prompt: reference/yolo/fastsam/prompt.md
- utils: reference/yolo/fastsam/utils.md
- val: reference/yolo/fastsam/val.md
- nas: - nas:
- model: reference/yolo/nas/model.md - model: reference/yolo/nas/model.md
- predict: reference/yolo/nas/predict.md - predict: reference/yolo/nas/predict.md

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.133' __version__ = '8.0.134'
from ultralytics.hub import start from ultralytics.hub import start
from ultralytics.vit.rtdetr import RTDETR from ultralytics.vit.rtdetr import RTDETR

@ -1,5 +1,8 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
from .build import build_sam # noqa from .model import SAM
from .model import SAM # noqa from .predict import Predictor
from .modules.prompt_predictor import PromptPredictor # noqa
# from .build import build_sam
__all__ = 'SAM', 'Predictor' # tuple or list

@ -1,94 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from copy import deepcopy
from typing import Tuple
import numpy as np
import torch
from torch.nn import functional as F
from torchvision.transforms.functional import resize, to_pil_image # type: ignore
class ResizeLongestSide:
"""
Resizes images to the longest side 'target_length', as well as provides
methods for resizing coordinates and boxes. Provides methods for
transforming both numpy array and batched torch tensors.
"""
def __init__(self, target_length: int) -> None:
self.target_length = target_length
def apply_image(self, image: np.ndarray) -> np.ndarray:
"""
Expects a numpy array with shape HxWxC in uint8 format.
"""
target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
return np.array(resize(to_pil_image(image), target_size))
def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
"""
Expects a numpy array of length 2 in the final dimension. Requires the
original image size in (H, W) format.
"""
old_h, old_w = original_size
new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
coords = deepcopy(coords).astype(float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
"""
Expects a numpy array shape Bx4. Requires the original image size
in (H, W) format.
"""
boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
return boxes.reshape(-1, 4)
def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
"""
Expects batched images with shape BxCxHxW and float format. This
transformation may not exactly match apply_image. apply_image is
the transformation expected by the model.
"""
# Expects an image in BCHW format. May not exactly match apply_image.
target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
return F.interpolate(image, target_size, mode='bilinear', align_corners=False, antialias=True)
def apply_coords_torch(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
"""
Expects a torch tensor with length 2 in the last dimension. Requires the
original image size in (H, W) format.
"""
old_h, old_w = original_size
new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
coords = deepcopy(coords).to(torch.float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
def apply_boxes_torch(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
"""
Expects a torch tensor with shape Bx4. Requires the original image
size in (H, W) format.
"""
boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
return boxes.reshape(-1, 4)
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
"""
Compute the output size given input size and target long side length.
"""
scale = long_side_length * 1.0 / max(oldh, oldw)
newh, neww = oldh * scale, oldw * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)
return (newh, neww)

@ -14,6 +14,7 @@ from ...yolo.utils.downloads import attempt_download_asset
from .modules.decoders import MaskDecoder from .modules.decoders import MaskDecoder
from .modules.encoders import ImageEncoderViT, PromptEncoder from .modules.encoders import ImageEncoderViT, PromptEncoder
from .modules.sam import Sam from .modules.sam import Sam
from .modules.tiny_encoder import TinyViT
from .modules.transformer import TwoWayTransformer from .modules.transformer import TwoWayTransformer
@ -50,20 +51,45 @@ def build_sam_vit_b(checkpoint=None):
) )
def _build_sam( def build_mobile_sam(checkpoint=None):
encoder_embed_dim, """Build and return Mobile Segment Anything Model (Mobile-SAM)."""
return _build_sam(
encoder_embed_dim=[64, 128, 160, 320],
encoder_depth=[2, 2, 6, 2],
encoder_num_heads=[2, 4, 5, 10],
encoder_global_attn_indexes=None,
mobile_sam=True,
checkpoint=checkpoint,
)
def _build_sam(encoder_embed_dim,
encoder_depth, encoder_depth,
encoder_num_heads, encoder_num_heads,
encoder_global_attn_indexes, encoder_global_attn_indexes,
checkpoint=None, checkpoint=None,
): mobile_sam=False):
"""Builds the selected SAM model architecture.""" """Builds the selected SAM model architecture."""
prompt_embed_dim = 256 prompt_embed_dim = 256
image_size = 1024 image_size = 1024
vit_patch_size = 16 vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size image_embedding_size = image_size // vit_patch_size
sam = Sam( image_encoder = (TinyViT(
image_encoder=ImageEncoderViT( img_size=1024,
in_chans=3,
num_classes=1000,
embed_dims=encoder_embed_dim,
depths=encoder_depth,
num_heads=encoder_num_heads,
window_sizes=[7, 7, 14, 7],
mlp_ratio=4.0,
drop_rate=0.0,
drop_path_rate=0.0,
use_checkpoint=False,
mbconv_expand_ratio=4.0,
local_conv_size=3,
layer_lr_decay=0.8,
) if mobile_sam else ImageEncoderViT(
depth=encoder_depth, depth=encoder_depth,
embed_dim=encoder_embed_dim, embed_dim=encoder_embed_dim,
img_size=image_size, img_size=image_size,
@ -76,7 +102,9 @@ def _build_sam(
global_attn_indexes=encoder_global_attn_indexes, global_attn_indexes=encoder_global_attn_indexes,
window_size=14, window_size=14,
out_chans=prompt_embed_dim, out_chans=prompt_embed_dim,
), ))
sam = Sam(
image_encoder=image_encoder,
prompt_encoder=PromptEncoder( prompt_encoder=PromptEncoder(
embed_dim=prompt_embed_dim, embed_dim=prompt_embed_dim,
image_embedding_size=(image_embedding_size, image_embedding_size), image_embedding_size=(image_embedding_size, image_embedding_size),
@ -98,20 +126,22 @@ def _build_sam(
pixel_mean=[123.675, 116.28, 103.53], pixel_mean=[123.675, 116.28, 103.53],
pixel_std=[58.395, 57.12, 57.375], pixel_std=[58.395, 57.12, 57.375],
) )
sam.eval()
if checkpoint is not None: if checkpoint is not None:
checkpoint = attempt_download_asset(checkpoint) checkpoint = attempt_download_asset(checkpoint)
with open(checkpoint, 'rb') as f: with open(checkpoint, 'rb') as f:
state_dict = torch.load(f) state_dict = torch.load(f)
sam.load_state_dict(state_dict) sam.load_state_dict(state_dict)
sam.eval()
# sam.load_state_dict(torch.load(checkpoint), strict=True)
# sam.eval()
return sam return sam
sam_model_map = { sam_model_map = {
# "default": build_sam_vit_h,
'sam_h.pt': build_sam_vit_h, 'sam_h.pt': build_sam_vit_h,
'sam_l.pt': build_sam_vit_l, 'sam_l.pt': build_sam_vit_l,
'sam_b.pt': build_sam_vit_b, } 'sam_b.pt': build_sam_vit_b,
'mobile_sam.pt': build_mobile_sam, }
def build_sam(ckpt='sam_b.pt'): def build_sam(ckpt='sam_b.pt'):

@ -4,8 +4,8 @@ SAM model interface
""" """
from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.utils.torch_utils import model_info
from ...yolo.utils.torch_utils import model_info
from .build import build_sam from .build import build_sam
from .predict import Predictor from .predict import Predictor
@ -20,16 +20,16 @@ class SAM:
self.task = 'segment' # required self.task = 'segment' # required
self.predictor = None # reuse predictor self.predictor = None # reuse predictor
def predict(self, source, stream=False, **kwargs): def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Predicts and returns segmentation masks for given image or video source.""" """Predicts and returns segmentation masks for given image or video source."""
overrides = dict(conf=0.25, task='segment', mode='predict') overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
overrides.update(kwargs) # prefer kwargs overrides.update(kwargs) # prefer kwargs
if not self.predictor: if not self.predictor:
self.predictor = Predictor(overrides=overrides) self.predictor = Predictor(overrides=overrides)
self.predictor.setup_model(model=self.model) self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides) self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source, stream=stream) return self.predictor(source, stream=stream, bboxes=bboxes, points=points, labels=labels)
def train(self, **kwargs): def train(self, **kwargs):
"""Function trains models but raises an error as SAM models do not support training.""" """Function trains models but raises an error as SAM models do not support training."""
@ -39,9 +39,9 @@ class SAM:
"""Run validation given dataset.""" """Run validation given dataset."""
raise NotImplementedError("SAM models don't support validation") raise NotImplementedError("SAM models don't support validation")
def __call__(self, source=None, stream=False, **kwargs): def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection.""" """Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs) return self.predict(source, stream, bboxes, points, labels, **kwargs)
def __getattr__(self, attr): def __getattr__(self, attr):
"""Raises error if object has no requested attribute.""" """Raises error if object has no requested attribute."""

@ -1,353 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
from ..amg import (MaskData, area_from_rle, batch_iterator, batched_mask_to_box, box_xyxy_to_xywh,
build_all_layer_point_grids, calculate_stability_score, coco_encode_rle, generate_crop_boxes,
is_box_near_crop_edge, mask_to_rle_pytorch, remove_small_regions, rle_to_mask, uncrop_boxes_xyxy,
uncrop_masks, uncrop_points)
from .prompt_predictor import PromptPredictor
from .sam import Sam
class SamAutomaticMaskGenerator:
def __init__(
self,
model: Sam,
points_per_side: Optional[int] = 32,
points_per_batch: int = 64,
pred_iou_thresh: float = 0.88,
stability_score_thresh: float = 0.95,
stability_score_offset: float = 1.0,
box_nms_thresh: float = 0.7,
crop_n_layers: int = 0,
crop_nms_thresh: float = 0.7,
crop_overlap_ratio: float = 512 / 1500,
crop_n_points_downscale_factor: int = 1,
point_grids: Optional[List[np.ndarray]] = None,
min_mask_region_area: int = 0,
output_mode: str = 'binary_mask',
) -> None:
"""
Using a SAM model, generates masks for the entire image.
Generates a grid of point prompts over the image, then filters
low quality and duplicate masks. The default settings are chosen
for SAM with a ViT-H backbone.
Arguments:
model (Sam): The SAM model to use for mask prediction.
points_per_side (int, None): The number of points to be sampled
along one side of the image. The total number of points is
points_per_side**2. If None, 'point_grids' must provide explicit
point sampling.
points_per_batch (int): Sets the number of points run simultaneously
by the model. Higher numbers may be faster but use more GPU memory.
pred_iou_thresh (float): A filtering threshold in [0,1], using the
model's predicted mask quality.
stability_score_thresh (float): A filtering threshold in [0,1], using
the stability of the mask under changes to the cutoff used to binarize
the model's mask predictions.
stability_score_offset (float): The amount to shift the cutoff when
calculated the stability score.
box_nms_thresh (float): The box IoU cutoff used by non-maximal
suppression to filter duplicate masks.
crop_n_layers (int): If >0, mask prediction will be run again on
crops of the image. Sets the number of layers to run, where each
layer has 2**i_layer number of image crops.
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
suppression to filter duplicate masks between different crops.
crop_overlap_ratio (float): Sets the degree to which crops overlap.
In the first crop layer, crops will overlap by this fraction of
the image length. Later layers with more crops scale down this overlap.
crop_n_points_downscale_factor (int): The number of points-per-side
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
point_grids (list(np.ndarray), None): A list over explicit grids
of points used for sampling, normalized to [0,1]. The nth grid in the
list is used in the nth crop layer. Exclusive with points_per_side.
min_mask_region_area (int): If >0, postprocessing will be applied
to remove disconnected regions and holes in masks with area smaller
than min_mask_region_area. Requires opencv.
output_mode (str): The form masks are returned in. Can be 'binary_mask',
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
For large resolutions, 'binary_mask' may consume large amounts of
memory.
"""
assert (points_per_side is None) != (point_grids is None), \
'Exactly one of points_per_side or point_grid must be provided.'
if points_per_side is not None:
self.point_grids = build_all_layer_point_grids(
points_per_side,
crop_n_layers,
crop_n_points_downscale_factor,
)
elif point_grids is not None:
self.point_grids = point_grids
else:
raise ValueError("Can't have both points_per_side and point_grid be None.")
assert output_mode in {'binary_mask', 'uncompressed_rle', 'coco_rle'}, f'Unknown output_mode {output_mode}.'
if output_mode == 'coco_rle':
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
if min_mask_region_area > 0:
import cv2 # type: ignore # noqa: F401
self.predictor = PromptPredictor(model)
self.points_per_batch = points_per_batch
self.pred_iou_thresh = pred_iou_thresh
self.stability_score_thresh = stability_score_thresh
self.stability_score_offset = stability_score_offset
self.box_nms_thresh = box_nms_thresh
self.crop_n_layers = crop_n_layers
self.crop_nms_thresh = crop_nms_thresh
self.crop_overlap_ratio = crop_overlap_ratio
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
self.min_mask_region_area = min_mask_region_area
self.output_mode = output_mode
# TODO: Temporary implementation for compatibility
def __call__(self, image: np.ndarray, augment=False, visualize=False) -> List[Dict[str, Any]]:
return self.generate(image)
@torch.no_grad()
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
"""
Generates masks for the given image.
Arguments:
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
Returns:
list(dict(str, any)): A list over records for masks. Each record is a dict containing the following keys:
segmentation (dict(str, any), np.ndarray): The mask. If
output_mode='binary_mask', is an array of shape HW. Otherwise,
is a dictionary containing the RLE.
bbox (list(float)): The box around the mask, in XYWH format.
area (int): The area in pixels of the mask.
predicted_iou (float): The model's own prediction of the mask's
quality. This is filtered by the pred_iou_thresh parameter.
point_coords (list(list(float))): The point coordinates input
to the model to generate this mask.
stability_score (float): A measure of the mask's quality. This
is filtered on using the stability_score_thresh parameter.
crop_box (list(float)): The crop of the image used to generate
the mask, given in XYWH format.
"""
# Generate masks
mask_data = self._generate_masks(image)
# Filter small disconnected regions and holes in masks
if self.min_mask_region_area > 0:
mask_data = self.postprocess_small_regions(
mask_data,
self.min_mask_region_area,
max(self.box_nms_thresh, self.crop_nms_thresh),
)
# Encode masks
if self.output_mode == 'coco_rle':
mask_data['segmentations'] = [coco_encode_rle(rle) for rle in mask_data['rles']]
elif self.output_mode == 'binary_mask':
mask_data['segmentations'] = [rle_to_mask(rle) for rle in mask_data['rles']]
else:
mask_data['segmentations'] = mask_data['rles']
# Write mask records
curr_anns = []
for idx in range(len(mask_data['segmentations'])):
ann = {
'segmentation': mask_data['segmentations'][idx],
'area': area_from_rle(mask_data['rles'][idx]),
'bbox': box_xyxy_to_xywh(mask_data['boxes'][idx]).tolist(),
'predicted_iou': mask_data['iou_preds'][idx].item(),
'point_coords': [mask_data['points'][idx].tolist()],
'stability_score': mask_data['stability_score'][idx].item(),
'crop_box': box_xyxy_to_xywh(mask_data['crop_boxes'][idx]).tolist(), }
curr_anns.append(ann)
return curr_anns
def _generate_masks(self, image: np.ndarray) -> MaskData:
orig_size = image.shape[:2]
crop_boxes, layer_idxs = generate_crop_boxes(orig_size, self.crop_n_layers, self.crop_overlap_ratio)
# Iterate over image crops
data = MaskData()
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
data.cat(crop_data)
# Remove duplicate masks between crops
if len(crop_boxes) > 1:
# Prefer masks from smaller crops
scores = 1 / box_area(data['crop_boxes'])
scores = scores.to(data['boxes'].device)
keep_by_nms = batched_nms(
data['boxes'].float(),
scores,
torch.zeros_like(data['boxes'][:, 0]), # categories
iou_threshold=self.crop_nms_thresh,
)
data.filter(keep_by_nms)
data.to_numpy()
return data
def _process_crop(
self,
image: np.ndarray,
crop_box: List[int],
crop_layer_idx: int,
orig_size: Tuple[int, ...],
) -> MaskData:
# Crop the image and calculate embeddings
x0, y0, x1, y1 = crop_box
cropped_im = image[y0:y1, x0:x1, :]
cropped_im_size = cropped_im.shape[:2]
self.predictor.set_image(cropped_im)
# Get points for this crop
points_scale = np.array(cropped_im_size)[None, ::-1]
points_for_image = self.point_grids[crop_layer_idx] * points_scale
# Generate masks for this crop in batches
data = MaskData()
for (points, ) in batch_iterator(self.points_per_batch, points_for_image):
batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
data.cat(batch_data)
del batch_data
self.predictor.reset_image()
# Remove duplicates within this crop.
keep_by_nms = batched_nms(
data['boxes'].float(),
data['iou_preds'],
torch.zeros_like(data['boxes'][:, 0]), # categories
iou_threshold=self.box_nms_thresh,
)
data.filter(keep_by_nms)
# Return to the original image frame
data['boxes'] = uncrop_boxes_xyxy(data['boxes'], crop_box)
data['points'] = uncrop_points(data['points'], crop_box)
data['crop_boxes'] = torch.tensor([crop_box for _ in range(len(data['rles']))])
return data
def _process_batch(
self,
points: np.ndarray,
im_size: Tuple[int, ...],
crop_box: List[int],
orig_size: Tuple[int, ...],
) -> MaskData:
orig_h, orig_w = orig_size
# Run model on this batch
transformed_points = self.predictor.transform.apply_coords(points, im_size)
in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
masks, iou_preds, _ = self.predictor.predict_torch(
in_points[:, None, :],
in_labels[:, None],
multimask_output=True,
return_logits=True,
)
# Serialize predictions and store in MaskData
data = MaskData(
masks=masks.flatten(0, 1),
iou_preds=iou_preds.flatten(0, 1),
points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
)
del masks
# Filter by predicted IoU
if self.pred_iou_thresh > 0.0:
keep_mask = data['iou_preds'] > self.pred_iou_thresh
data.filter(keep_mask)
# Calculate stability score
data['stability_score'] = calculate_stability_score(data['masks'], self.predictor.model.mask_threshold,
self.stability_score_offset)
if self.stability_score_thresh > 0.0:
keep_mask = data['stability_score'] >= self.stability_score_thresh
data.filter(keep_mask)
# Threshold masks and calculate boxes
data['masks'] = data['masks'] > self.predictor.model.mask_threshold
data['boxes'] = batched_mask_to_box(data['masks'])
# Filter boxes that touch crop boundaries
keep_mask = ~is_box_near_crop_edge(data['boxes'], crop_box, [0, 0, orig_w, orig_h])
if not torch.all(keep_mask):
data.filter(keep_mask)
# Compress to RLE
data['masks'] = uncrop_masks(data['masks'], crop_box, orig_h, orig_w)
data['rles'] = mask_to_rle_pytorch(data['masks'])
del data['masks']
return data
@staticmethod
def postprocess_small_regions(mask_data: MaskData, min_area: int, nms_thresh: float) -> MaskData:
"""
Removes small disconnected regions and holes in masks, then reruns
box NMS to remove any new duplicates.
Edits mask_data in place.
Requires open-cv as a dependency.
"""
if len(mask_data['rles']) == 0:
return mask_data
# Filter small disconnected regions and holes
new_masks = []
scores = []
for rle in mask_data['rles']:
mask = rle_to_mask(rle)
mask, changed = remove_small_regions(mask, min_area, mode='holes')
unchanged = not changed
mask, changed = remove_small_regions(mask, min_area, mode='islands')
unchanged = unchanged and not changed
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
# Give score=0 to changed masks and score=1 to unchanged masks
# so NMS will prefer ones that didn't need postprocessing
scores.append(float(unchanged))
# Recalculate boxes and remove any new duplicates
masks = torch.cat(new_masks, dim=0)
boxes = batched_mask_to_box(masks)
keep_by_nms = batched_nms(
boxes.float(),
torch.as_tensor(scores),
torch.zeros_like(boxes[:, 0]), # categories
iou_threshold=nms_thresh,
)
# Only recalculate RLEs for masks that have changed
for i_mask in keep_by_nms:
if scores[i_mask] == 0.0:
mask_torch = masks[i_mask].unsqueeze(0)
mask_data['rles'][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
mask_data['boxes'][i_mask] = boxes[i_mask] # update res directly
mask_data.filter(keep_by_nms)
return mask_data

@ -1,242 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from typing import Optional, Tuple
import numpy as np
import torch
from ..autosize import ResizeLongestSide
from .sam import Sam
class PromptPredictor:
def __init__(self, sam_model: Sam) -> None:
"""
Uses SAM to calculate the image embedding for an image, and then
allow repeated, efficient mask prediction given prompts.
Arguments:
sam_model (Sam): The model to use for mask prediction.
"""
super().__init__()
self.model = sam_model
self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
self.reset_image()
def set_image(self, image: np.ndarray, image_format: str = 'RGB') -> None:
"""
Calculates the image embeddings for the provided image, allowing
masks to be predicted with the 'predict' method.
Arguments:
image (np.ndarray): The image for calculating masks. Expects an
image in HWC uint8 format, with pixel values in [0, 255].
image_format (str): The color format of the image, in ['RGB', 'BGR'].
"""
assert image_format in {'RGB', 'BGR'}, f"image_format must be in ['RGB', 'BGR'], is {image_format}."
if image_format != self.model.image_format:
image = image[..., ::-1]
# Transform the image to the form expected by the model
input_image = self.transform.apply_image(image)
input_image_torch = torch.as_tensor(input_image, device=self.device)
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
self.set_torch_image(input_image_torch, image.shape[:2])
@torch.no_grad()
def set_torch_image(self, transformed_image: torch.Tensor, original_image_size: Tuple[int, ...]) -> None:
"""
Calculates the image embeddings for the provided image, allowing
masks to be predicted with the 'predict' method. Expects the input
image to be already transformed to the format expected by the model.
Arguments:
transformed_image (torch.Tensor): The input image, with shape
1x3xHxW, which has been transformed with ResizeLongestSide.
original_image_size (tuple(int, int)): The size of the image
before transformation, in (H, W) format.
"""
if len(transformed_image.shape) != 4 \
or transformed_image.shape[1] != 3 \
or max(*transformed_image.shape[2:]) != self.model.image_encoder.img_size:
raise ValueError('set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}.')
self.reset_image()
self.original_size = original_image_size
self.input_size = tuple(transformed_image.shape[-2:])
input_image = self.model.preprocess(transformed_image)
self.features = self.model.image_encoder(input_image)
self.is_image_set = True
def predict(
self,
point_coords: Optional[np.ndarray] = None,
point_labels: Optional[np.ndarray] = None,
box: Optional[np.ndarray] = None,
mask_input: Optional[np.ndarray] = None,
multimask_output: bool = True,
return_logits: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Predict masks for the given input prompts, using the currently set image.
Arguments:
point_coords (np.ndarray, None): A Nx2 array of point prompts to the
model. Each point is in (X,Y) in pixels.
point_labels (np.ndarray, None): A length N array of labels for the
point prompts. 1 indicates a foreground point and 0 indicates a
background point.
box (np.ndarray, None): A length 4 array given a box prompt to the
model, in XYXY format.
mask_input (np.ndarray): A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form 1xHxW, where
for SAM, H=W=256.
multimask_output (bool): If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will often
produce better masks than a single prediction. If only a single
mask is needed, the model's predicted quality score can be used
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results.
return_logits (bool): If true, returns un-thresholded masks logits
instead of a binary mask.
Returns:
(np.ndarray): The output masks in CxHxW format, where C is the
number of masks, and (H, W) is the original image size.
(np.ndarray): An array of length C containing the model's
predictions for the quality of each mask.
(np.ndarray): An array of shape CxHxW, where C is the number
of masks and H=W=256. These low resolution logits can be passed to
a subsequent iteration as mask input.
"""
if not self.is_image_set:
raise RuntimeError('An image must be set with .set_image(...) before mask prediction.')
# Transform input prompts
coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
if point_coords is not None:
assert (point_labels is not None), 'point_labels must be supplied if point_coords is supplied.'
point_coords = self.transform.apply_coords(point_coords, self.original_size)
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
if box is not None:
box = self.transform.apply_boxes(box, self.original_size)
box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
box_torch = box_torch[None, :]
if mask_input is not None:
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
mask_input_torch = mask_input_torch[None, :, :, :]
masks, iou_predictions, low_res_masks = self.predict_torch(
coords_torch,
labels_torch,
box_torch,
mask_input_torch,
multimask_output,
return_logits=return_logits,
)
masks_np = masks[0].detach().cpu().numpy()
iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
return masks_np, iou_predictions_np, low_res_masks_np
@torch.no_grad()
def predict_torch(
self,
point_coords: Optional[torch.Tensor],
point_labels: Optional[torch.Tensor],
boxes: Optional[torch.Tensor] = None,
mask_input: Optional[torch.Tensor] = None,
multimask_output: bool = True,
return_logits: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Predict masks for the given input prompts, using the currently set image.
Input prompts are batched torch tensors and are expected to already be
transformed to the input frame using ResizeLongestSide.
Arguments:
point_coords (torch.Tensor, None): A BxNx2 array of point prompts to the
model. Each point is in (X,Y) in pixels.
point_labels (torch.Tensor, None): A BxN array of labels for the
point prompts. 1 indicates a foreground point and 0 indicates a
background point.
boxes (np.ndarray, None): A Bx4 array given a box prompt to the
model, in XYXY format.
mask_input (np.ndarray): A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form Bx1xHxW, where
for SAM, H=W=256. Masks returned by a previous iteration of the
predict method do not need further transformation.
multimask_output (bool): If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will often
produce better masks than a single prediction. If only a single
mask is needed, the model's predicted quality score can be used
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results.
return_logits (bool): If true, returns un-thresholded masks logits
instead of a binary mask.
Returns:
(torch.Tensor): The output masks in BxCxHxW format, where C is the
number of masks, and (H, W) is the original image size.
(torch.Tensor): An array of shape BxC containing the model's
predictions for the quality of each mask.
(torch.Tensor): An array of shape BxCxHxW, where C is the number
of masks and H=W=256. These low res logits can be passed to
a subsequent iteration as mask input.
"""
if not self.is_image_set:
raise RuntimeError('An image must be set with .set_image(...) before mask prediction.')
points = (point_coords, point_labels) if point_coords is not None else None
# Embed prompts
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
points=points,
boxes=boxes,
masks=mask_input,
)
# Predict masks
low_res_masks, iou_predictions = self.model.mask_decoder(
image_embeddings=self.features,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
# Upscale the masks to the original image resolution
masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
if not return_logits:
masks = masks > self.model.mask_threshold
return masks, iou_predictions, low_res_masks
def get_image_embedding(self) -> torch.Tensor:
"""
Returns the image embeddings for the currently set image, with
shape 1xCxHxW, where C is the embedding dimension and (H,W) are
the embedding spatial dimension of SAM (typically C=256, H=W=64).
"""
if not self.is_image_set:
raise RuntimeError('An image must be set with .set_image(...) to generate an embedding.')
assert self.features is not None, 'Features must exist if an image has been set.'
return self.features
@property
def device(self) -> torch.device:
return self.model.device
def reset_image(self) -> None:
"""Resets the currently set image."""
self.is_image_set = False
self.features = None
self.orig_h = None
self.orig_w = None
self.input_h = None
self.input_w = None

@ -0,0 +1,653 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# --------------------------------------------------------
# TinyViT Model Architecture
# Copyright (c) 2022 Microsoft
# Adapted from LeViT and Swin Transformer
# LeViT: (https://github.com/facebookresearch/levit)
# Swin: (https://github.com/microsoft/swin-transformer)
# Build the TinyViT Model
# --------------------------------------------------------
import itertools
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from ultralytics.yolo.utils.instance import to_2tuple
class Conv2d_BN(torch.nn.Sequential):
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
super().__init__()
self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = torch.nn.BatchNorm2d(b)
torch.nn.init.constant_(bn.weight, bn_weight_init)
torch.nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
@torch.no_grad()
def fuse(self):
c, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / \
(bn.running_var + bn.eps)**0.5
m = torch.nn.Conv2d(w.size(1) * self.c.groups,
w.size(0),
w.shape[2:],
stride=self.c.stride,
padding=self.c.padding,
dilation=self.c.dilation,
groups=self.c.groups)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
# NOTE: This module and timm package is needed only for training.
# from ultralytics.yolo.utils.checks import check_requirements
# check_requirements('timm')
# from timm.models.layers import DropPath as TimmDropPath
# from timm.models.layers import trunc_normal_
# class DropPath(TimmDropPath):
#
# def __init__(self, drop_prob=None):
# super().__init__(drop_prob=drop_prob)
# self.drop_prob = drop_prob
#
# def __repr__(self):
# msg = super().__repr__()
# msg += f'(drop_prob={self.drop_prob})'
# return msg
class PatchEmbed(nn.Module):
def __init__(self, in_chans, embed_dim, resolution, activation):
super().__init__()
img_size: Tuple[int, int] = to_2tuple(resolution)
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
self.num_patches = self.patches_resolution[0] * \
self.patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
n = embed_dim
self.seq = nn.Sequential(
Conv2d_BN(in_chans, n // 2, 3, 2, 1),
activation(),
Conv2d_BN(n // 2, n, 3, 2, 1),
)
def forward(self, x):
return self.seq(x)
class MBConv(nn.Module):
def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
super().__init__()
self.in_chans = in_chans
self.hidden_chans = int(in_chans * expand_ratio)
self.out_chans = out_chans
self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
self.act1 = activation()
self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, ks=3, stride=1, pad=1, groups=self.hidden_chans)
self.act2 = activation()
self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
self.act3 = activation()
# NOTE: `DropPath` is needed only for training.
# self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path = nn.Identity()
def forward(self, x):
shortcut = x
x = self.conv1(x)
x = self.act1(x)
x = self.conv2(x)
x = self.act2(x)
x = self.conv3(x)
x = self.drop_path(x)
x += shortcut
x = self.act3(x)
return x
class PatchMerging(nn.Module):
def __init__(self, input_resolution, dim, out_dim, activation):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.out_dim = out_dim
self.act = activation()
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
stride_c = 2
if (out_dim == 320 or out_dim == 448 or out_dim == 576):
stride_c = 1
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
def forward(self, x):
if x.ndim == 3:
H, W = self.input_resolution
B = len(x)
# (B, C, H, W)
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
x = self.conv1(x)
x = self.act(x)
x = self.conv2(x)
x = self.act(x)
x = self.conv3(x)
x = x.flatten(2).transpose(1, 2)
return x
class ConvLayer(nn.Module):
def __init__(
self,
dim,
input_resolution,
depth,
activation,
drop_path=0.,
downsample=None,
use_checkpoint=False,
out_dim=None,
conv_expand_ratio=4.,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
MBConv(
dim,
dim,
conv_expand_ratio,
activation,
drop_path[i] if isinstance(drop_path, list) else drop_path,
) for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.norm = nn.LayerNorm(in_features)
self.fc1 = nn.Linear(in_features, hidden_features)
self.fc2 = nn.Linear(hidden_features, out_features)
self.act = act_layer()
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.norm(x)
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(torch.nn.Module):
def __init__(
self,
dim,
key_dim,
num_heads=8,
attn_ratio=4,
resolution=(14, 14),
):
super().__init__()
# (h, w)
assert isinstance(resolution, tuple) and len(resolution) == 2
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
h = self.dh + nh_kd * 2
self.norm = nn.LayerNorm(dim)
self.qkv = nn.Linear(dim, h)
self.proj = nn.Linear(self.dh, dim)
points = list(itertools.product(range(resolution[0]), range(resolution[1])))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and hasattr(self, 'ab'):
del self.ab
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x): # x (B,N,C)
B, N, _ = x.shape
# Normalization
x = self.norm(x)
qkv = self.qkv(x)
# (B, N, num_heads, d)
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
# (B, num_heads, N, d)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
self.ab = self.ab.to(self.attention_biases.device)
attn = ((q @ k.transpose(-2, -1)) * self.scale +
(self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab))
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
x = self.proj(x)
return x
class TinyViTBlock(nn.Module):
r""" TinyViT Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int, int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float, optional): Dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
local_conv_size (int): the kernel size of the convolution between
Attention and MLP. Default: 3
activation (torch.nn): the activation function. Default: nn.GELU
"""
def __init__(
self,
dim,
input_resolution,
num_heads,
window_size=7,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
local_conv_size=3,
activation=nn.GELU,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
assert window_size > 0, 'window_size must be greater than 0'
self.window_size = window_size
self.mlp_ratio = mlp_ratio
# NOTE: `DropPath` is needed only for training.
# self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path = nn.Identity()
assert dim % num_heads == 0, 'dim must be divisible by num_heads'
head_dim = dim // num_heads
window_resolution = (window_size, window_size)
self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution)
mlp_hidden_dim = int(dim * mlp_ratio)
mlp_activation = activation
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=mlp_activation, drop=drop)
pad = local_conv_size // 2
self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, 'input feature has wrong size'
res_x = x
if H == self.window_size and W == self.window_size:
x = self.attn(x)
else:
x = x.view(B, H, W, C)
pad_b = (self.window_size - H % self.window_size) % self.window_size
pad_r = (self.window_size - W % self.window_size) % self.window_size
padding = pad_b > 0 or pad_r > 0
if padding:
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
pH, pW = H + pad_b, W + pad_r
nH = pH // self.window_size
nW = pW // self.window_size
# window partition
x = x.view(B, nH, self.window_size, nW, self.window_size,
C).transpose(2, 3).reshape(B * nH * nW, self.window_size * self.window_size, C)
x = self.attn(x)
# window reverse
x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C)
if padding:
x = x[:, :H, :W].contiguous()
x = x.view(B, L, C)
x = res_x + self.drop_path(x)
x = x.transpose(1, 2).reshape(B, C, H, W)
x = self.local_conv(x)
x = x.view(B, C, L).transpose(1, 2)
x = x + self.drop_path(self.mlp(x))
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}'
class BasicLayer(nn.Module):
""" A basic TinyViT layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float, optional): Dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
local_conv_size (int): the kernel size of the depthwise convolution between attention and MLP. Default: 3
activation (torch.nn): the activation function. Default: nn.GELU
out_dim (int | optional): the output dimension of the layer. Default: None
"""
def __init__(
self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
downsample=None,
use_checkpoint=False,
local_conv_size=3,
activation=nn.GELU,
out_dim=None,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
TinyViTBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
drop=drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
local_conv_size=local_conv_size,
activation=activation,
) for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class TinyViT(nn.Module):
def __init__(
self,
img_size=224,
in_chans=3,
num_classes=1000,
embed_dims=[96, 192, 384, 768],
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_sizes=[7, 7, 14, 7],
mlp_ratio=4.,
drop_rate=0.,
drop_path_rate=0.1,
use_checkpoint=False,
mbconv_expand_ratio=4.0,
local_conv_size=3,
layer_lr_decay=1.0,
):
super().__init__()
self.img_size = img_size
self.num_classes = num_classes
self.depths = depths
self.num_layers = len(depths)
self.mlp_ratio = mlp_ratio
activation = nn.GELU
self.patch_embed = PatchEmbed(in_chans=in_chans,
embed_dim=embed_dims[0],
resolution=img_size,
activation=activation)
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
kwargs = dict(
dim=embed_dims[i_layer],
input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
# patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
out_dim=embed_dims[min(i_layer + 1,
len(embed_dims) - 1)],
activation=activation,
)
if i_layer == 0:
layer = ConvLayer(
conv_expand_ratio=mbconv_expand_ratio,
**kwargs,
)
else:
layer = BasicLayer(num_heads=num_heads[i_layer],
window_size=window_sizes[i_layer],
mlp_ratio=self.mlp_ratio,
drop=drop_rate,
local_conv_size=local_conv_size,
**kwargs)
self.layers.append(layer)
# Classifier head
self.norm_head = nn.LayerNorm(embed_dims[-1])
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
# init weights
self.apply(self._init_weights)
self.set_layer_lr_decay(layer_lr_decay)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dims[-1],
256,
kernel_size=1,
bias=False,
),
LayerNorm2d(256),
nn.Conv2d(
256,
256,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(256),
)
def set_layer_lr_decay(self, layer_lr_decay):
decay_rate = layer_lr_decay
# layers -> blocks (depth)
depth = sum(self.depths)
lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
def _set_lr_scale(m, scale):
for p in m.parameters():
p.lr_scale = scale
self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
i = 0
for layer in self.layers:
for block in layer.blocks:
block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
i += 1
if layer.downsample is not None:
layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1]))
assert i == depth
for m in [self.norm_head, self.head]:
m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))
for k, p in self.named_parameters():
p.param_name = k
def _check_lr_scale(m):
for p in m.parameters():
assert hasattr(p, 'lr_scale'), p.param_name
self.apply(_check_lr_scale)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# NOTE: This initialization is needed only for training.
# trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'attention_biases'}
def forward_features(self, x):
# x: (N, C, H, W)
x = self.patch_embed(x)
x = self.layers[0](x)
start_i = 1
for i in range(start_i, len(self.layers)):
layer = self.layers[i]
x = layer(x)
B, _, C = x.size()
x = x.view(B, 64, 64, C)
x = x.permute(0, 3, 1, 2)
x = self.neck(x)
return x
def forward(self, x):
x = self.forward_features(x)
return x

@ -2,32 +2,298 @@
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F
import torchvision
from ultralytics.yolo.data.augment import LetterBox
from ultralytics.yolo.engine.predictor import BasePredictor from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.engine.results import Results from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.utils import DEFAULT_CFG, ops
from ultralytics.yolo.utils.torch_utils import select_device from ultralytics.yolo.utils.torch_utils import select_device
from .modules.mask_generator import SamAutomaticMaskGenerator from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score,
generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks)
from .build import build_sam
class Predictor(BasePredictor): class Predictor(BasePredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides={}, _callbacks=None):
overrides.update(dict(task='segment', mode='predict', imgsz=1024))
super().__init__(cfg, overrides, _callbacks)
# SAM needs retina_masks=True, or the results would be a mess.
self.args.retina_masks = True
# Args for set_image
self.im = None
self.features = None
# Args for segment everything
self.segment_all = False
def preprocess(self, im): def preprocess(self, im):
"""Prepares input image for inference.""" """Prepares input image before inference.
# TODO: Only support bs=1 for now
# im = ResizeLongestSide(1024).apply_image(im[0]) Args:
# im = torch.as_tensor(im, device=self.device) im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
# im = im.permute(2, 0, 1).contiguous()[None, :, :, :] """
return im[0] if self.im is not None:
return self.im
not_tensor = not isinstance(im, torch.Tensor)
if not_tensor:
im = np.stack(self.pre_transform(im))
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
im = np.ascontiguousarray(im) # contiguous
im = torch.from_numpy(im)
img = im.to(self.device)
img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
if not_tensor:
img = (img - self.mean) / self.std
return img
def pre_transform(self, im):
"""Pre-transform input image before inference.
Args:
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
Return: A list of transformed imgs.
"""
assert len(im) == 1, 'SAM model has not supported batch inference yet!'
return [LetterBox(self.args.imgsz, auto=False, center=False)(image=x) for x in im]
def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
"""
Predict masks for the given input prompts, using the currently set image.
Args:
im (torch.Tensor): The preprocessed image, (N, C, H, W).
bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
labels (np.ndarray | List, None): (N, ), labels for the point prompts.
1 indicates a foreground point and 0 indicates a background point.
masks (np.ndarray, None): A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form (N, H, W), where
for SAM, H=W=256.
multimask_output (bool): If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will often
produce better masks than a single prediction. If only a single
mask is needed, the model's predicted quality score can be used
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results.
Returns:
(np.ndarray): The output masks in CxHxW format, where C is the
number of masks, and (H, W) is the original image size.
(np.ndarray): An array of length C containing the model's
predictions for the quality of each mask.
(np.ndarray): An array of shape CxHxW, where C is the number
of masks and H=W=256. These low resolution logits can be passed to
a subsequent iteration as mask input.
"""
if all([i is None for i in [bboxes, points, masks]]):
return self.generate(im, *args, **kwargs)
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
"""
Predict masks for the given input prompts, using the currently set image.
Args:
im (torch.Tensor): The preprocessed image, (N, C, H, W).
bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
labels (np.ndarray | List, None): (N, ), labels for the point prompts.
1 indicates a foreground point and 0 indicates a background point.
masks (np.ndarray, None): A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form (N, H, W), where
for SAM, H=W=256.
multimask_output (bool): If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will often
produce better masks than a single prediction. If only a single
mask is needed, the model's predicted quality score can be used
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results.
Returns:
(np.ndarray): The output masks in CxHxW format, where C is the
number of masks, and (H, W) is the original image size.
(np.ndarray): An array of length C containing the model's
predictions for the quality of each mask.
(np.ndarray): An array of shape CxHxW, where C is the number
of masks and H=W=256. These low resolution logits can be passed to
a subsequent iteration as mask input.
"""
features = self.model.image_encoder(im) if self.features is None else self.features
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
# Transform input prompts
if points is not None:
points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
points = points[None] if points.ndim == 1 else points
# Assuming labels are all positive if users don't pass labels.
if labels is None:
labels = np.ones(points.shape[0])
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
points *= r
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
points, labels = points[:, None, :], labels[:, None]
if bboxes is not None:
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
bboxes *= r
if masks is not None:
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device)
masks = masks[:, None, :, :]
points = (points, labels) if points is not None else None
# Embed prompts
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
points=points,
boxes=bboxes,
masks=masks,
)
# Predict masks
pred_masks, pred_scores = self.model.mask_decoder(
image_embeddings=features,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
# `d` could be 1 or 3 depends on `multimask_output`.
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
def generate(self,
im,
crop_n_layers=0,
crop_overlap_ratio=512 / 1500,
crop_downscale_factor=1,
point_grids=None,
points_stride=32,
points_batch_size=64,
conf_thres=0.88,
stability_score_thresh=0.95,
stability_score_offset=0.95,
crop_nms_thresh=0.7):
"""Segment the whole image.
Args:
im (torch.Tensor): The preprocessed image, (N, C, H, W).
crop_n_layers (int): If >0, mask prediction will be run again on
crops of the image. Sets the number of layers to run, where each
layer has 2**i_layer number of image crops.
crop_overlap_ratio (float): Sets the degree to which crops overlap.
In the first crop layer, crops will overlap by this fraction of
the image length. Later layers with more crops scale down this overlap.
crop_downscale_factor (int): The number of points-per-side
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
point_grids (list(np.ndarray), None): A list over explicit grids
of points used for sampling, normalized to [0,1]. The nth grid in the
list is used in the nth crop layer. Exclusive with points_per_side.
points_stride (int, None): The number of points to be sampled
along one side of the image. The total number of points is
points_per_side**2. If None, 'point_grids' must provide explicit
point sampling.
points_batch_size (int): Sets the number of points run simultaneously
by the model. Higher numbers may be faster but use more GPU memory.
conf_thres (float): A filtering threshold in [0,1], using the
model's predicted mask quality.
stability_score_thresh (float): A filtering threshold in [0,1], using
the stability of the mask under changes to the cutoff used to binarize
the model's mask predictions.
stability_score_offset (float): The amount to shift the cutoff when
calculated the stability score.
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
suppression to filter duplicate masks between different crops.
"""
self.segment_all = True
ih, iw = im.shape[2:]
crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
if point_grids is None:
point_grids = build_all_layer_point_grids(
points_stride,
crop_n_layers,
crop_downscale_factor,
)
pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], []
for crop_region, layer_idx in zip(crop_regions, layer_idxs):
x1, y1, x2, y2 = crop_region
w, h = x2 - x1, y2 - y1
area = torch.tensor(w * h, device=im.device)
points_scale = np.array([[w, h]]) # w, h
# Crop image and interpolate to input size
crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode='bilinear', align_corners=False)
# (num_points, 2)
points_for_image = point_grids[layer_idx] * points_scale
crop_masks, crop_scores, crop_bboxes = [], [], []
for (points, ) in batch_iterator(points_batch_size, points_for_image):
pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
# Interpolate predicted masks to input size
pred_mask = F.interpolate(pred_mask[None], (h, w), mode='bilinear', align_corners=False)[0]
idx = pred_score > conf_thres
pred_mask, pred_score = pred_mask[idx], pred_score[idx]
stability_score = calculate_stability_score(pred_mask, self.model.mask_threshold,
stability_score_offset)
idx = stability_score > stability_score_thresh
pred_mask, pred_score = pred_mask[idx], pred_score[idx]
# Bool type is much more memory-efficient.
pred_mask = pred_mask > self.model.mask_threshold
# (N, 4)
pred_bbox = batched_mask_to_box(pred_mask).float()
keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih])
if not torch.all(keep_mask):
pred_bbox = pred_bbox[keep_mask]
pred_mask = pred_mask[keep_mask]
pred_score = pred_score[keep_mask]
crop_masks.append(pred_mask)
crop_bboxes.append(pred_bbox)
crop_scores.append(pred_score)
# Do nms within this crop
crop_masks = torch.cat(crop_masks)
crop_bboxes = torch.cat(crop_bboxes)
crop_scores = torch.cat(crop_scores)
keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou) # NMS
crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region)
crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw)
crop_scores = crop_scores[keep]
pred_masks.append(crop_masks)
pred_bboxes.append(crop_bboxes)
pred_scores.append(crop_scores)
region_areas.append(area.expand(len(crop_masks)))
pred_masks = torch.cat(pred_masks)
pred_bboxes = torch.cat(pred_bboxes)
pred_scores = torch.cat(pred_scores)
region_areas = torch.cat(region_areas)
# Remove duplicate masks between crops
if len(crop_regions) > 1:
scores = 1 / region_areas
keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh)
pred_masks = pred_masks[keep]
pred_bboxes = pred_bboxes[keep]
pred_scores = pred_scores[keep]
return pred_masks, pred_scores, pred_bboxes
def setup_model(self, model): def setup_model(self, model):
"""Set up YOLO model with specified thresholds and device.""" """Set up YOLO model with specified thresholds and device."""
device = select_device(self.args.device) device = select_device(self.args.device)
if model is None:
model = build_sam(self.args.model)
model.eval() model.eval()
self.model = SamAutomaticMaskGenerator(model.to(device), self.model = model.to(device)
pred_iou_thresh=self.args.conf,
box_nms_thresh=self.args.iou)
self.device = device self.device = device
self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
# TODO: Temporary settings for compatibility # TODO: Temporary settings for compatibility
self.model.pt = False self.model.pt = False
self.model.triton = False self.model.triton = False
@ -35,20 +301,96 @@ class Predictor(BasePredictor):
self.model.fp16 = False self.model.fp16 = False
self.done_warmup = True self.done_warmup = True
def postprocess(self, preds, path, orig_imgs): def postprocess(self, preds, img, orig_imgs):
"""Postprocesses inference output predictions to create detection masks for objects.""" """Postprocesses inference output predictions to create detection masks for objects."""
names = dict(enumerate(list(range(len(preds))))) # (N, 1, H, W), (N, 1)
pred_masks, pred_scores = preds[:2]
pred_bboxes = preds[2] if self.segment_all else None
names = dict(enumerate([str(i) for i in range(len(pred_masks))]))
results = [] results = []
# TODO for i, masks in enumerate([pred_masks]):
for i, pred in enumerate([preds]):
masks = torch.from_numpy(np.stack([p['segmentation'] for p in pred], axis=0))
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
if pred_bboxes is not None:
pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
masks = masks > self.model.mask_threshold # to bool
path = self.batch[0] path = self.batch[0]
img_path = path[i] if isinstance(path, list) else path img_path = path[i] if isinstance(path, list) else path
results.append(Results(orig_img=orig_img, path=img_path, names=names, masks=masks)) results.append(Results(orig_img=orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
# Reset segment-all mode.
self.segment_all = False
return results return results
# def __call__(self, source=None, model=None, stream=False): def setup_source(self, source):
# frame = cv2.imread(source) """Sets up source and inference mode."""
# preds = self.model.generate(frame) if source is not None:
# return self.postprocess(preds, source, frame) super().setup_source(source)
def set_image(self, image):
"""Set image in advance.
Args:
image (str | np.ndarray): image file path or np.ndarray image by cv2.
"""
if self.model is None:
model = build_sam(self.args.model)
self.setup_model(model)
self.setup_source(image)
assert len(self.dataset) == 1, '`set_image` only supports setting one image!'
for batch in self.dataset:
im = self.preprocess(batch[1])
self.features = self.model.image_encoder(im)
self.im = im
break
def reset_image(self):
self.im = None
self.features = None
@staticmethod
def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
"""
Removes small disconnected regions and holes in masks, then reruns
box NMS to remove any new duplicates. Requires open-cv as a dependency.
Args:
masks (torch.Tensor): Masks, (N, H, W).
min_area (int): Minimum area threshold.
nms_thresh (float): NMS threshold.
"""
if len(masks) == 0:
return masks
# Filter small disconnected regions and holes
new_masks = []
scores = []
for mask in masks:
mask = mask.cpu().numpy()
mask, changed = remove_small_regions(mask, min_area, mode='holes')
unchanged = not changed
mask, changed = remove_small_regions(mask, min_area, mode='islands')
unchanged = unchanged and not changed
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
# Give score=0 to changed masks and score=1 to unchanged masks
# so NMS will prefer ones that didn't need postprocessing
scores.append(float(unchanged))
# Recalculate boxes and remove any new duplicates
new_masks = torch.cat(new_masks, dim=0)
boxes = batched_mask_to_box(new_masks)
keep = torchvision.ops.nms(
boxes.float(),
torch.as_tensor(scores),
nms_thresh,
)
# Only recalculate masks for masks that have changed
for i in keep:
if scores[i] == 0.0:
masks[i] = new_masks[i]
return masks[keep]

@ -1,8 +1,6 @@
from pathlib import Path from pathlib import Path
from ultralytics import YOLO from ultralytics import SAM, YOLO
from ultralytics.vit.sam import PromptPredictor, build_sam
from ultralytics.yolo.utils.torch_utils import select_device
def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None): def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None):
@ -16,33 +14,21 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='',
output_dir (str | None | optional): Directory to save the annotated results. output_dir (str | None | optional): Directory to save the annotated results.
Defaults to a 'labels' folder in the same directory as 'data'. Defaults to a 'labels' folder in the same directory as 'data'.
""" """
device = select_device(device)
det_model = YOLO(det_model) det_model = YOLO(det_model)
sam_model = build_sam(sam_model) sam_model = SAM(sam_model)
det_model.to(device)
sam_model.to(device)
if not output_dir: if not output_dir:
output_dir = Path(str(data)).parent / 'labels' output_dir = Path(str(data)).parent / 'labels'
Path(output_dir).mkdir(exist_ok=True, parents=True) Path(output_dir).mkdir(exist_ok=True, parents=True)
prompt_predictor = PromptPredictor(sam_model) det_results = det_model(data, stream=True, device=device)
det_results = det_model(data, stream=True)
for result in det_results: for result in det_results:
boxes = result.boxes.xyxy # Boxes object for bbox outputs boxes = result.boxes.xyxy # Boxes object for bbox outputs
class_ids = result.boxes.cls.int().tolist() # noqa class_ids = result.boxes.cls.int().tolist() # noqa
if len(class_ids): if len(class_ids):
prompt_predictor.set_image(result.orig_img) sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)
masks, _, _ = prompt_predictor.predict_torch( segments = sam_results[0].masks.xyn # noqa
point_coords=None,
point_labels=None,
boxes=prompt_predictor.transform.apply_boxes_torch(boxes, result.orig_shape[:2]),
multimask_output=False,
)
result.update(masks=masks.squeeze(1))
segments = result.masks.xyn # noqa
with open(str(Path(output_dir) / Path(result.path).stem) + '.txt', 'w') as f: with open(str(Path(output_dir) / Path(result.path).stem) + '.txt', 'w') as f:
for i in range(len(segments)): for i in range(len(segments)):

@ -538,13 +538,14 @@ class RandomFlip:
class LetterBox: class LetterBox:
"""Resize image and padding for detection, instance segmentation, pose.""" """Resize image and padding for detection, instance segmentation, pose."""
def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32): def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, center=True, stride=32):
"""Initialize LetterBox object with specific parameters.""" """Initialize LetterBox object with specific parameters."""
self.new_shape = new_shape self.new_shape = new_shape
self.auto = auto self.auto = auto
self.scaleFill = scaleFill self.scaleFill = scaleFill
self.scaleup = scaleup self.scaleup = scaleup
self.stride = stride self.stride = stride
self.center = center # Put the image in the middle or top-left
def __call__(self, labels=None, image=None): def __call__(self, labels=None, image=None):
"""Return updated labels and image with added border.""" """Return updated labels and image with added border."""
@ -572,6 +573,7 @@ class LetterBox:
new_unpad = (new_shape[1], new_shape[0]) new_unpad = (new_shape[1], new_shape[0])
ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
if self.center:
dw /= 2 # divide padding into 2 sides dw /= 2 # divide padding into 2 sides
dh /= 2 dh /= 2
if labels.get('ratio_pad'): if labels.get('ratio_pad'):
@ -579,8 +581,8 @@ class LetterBox:
if shape[::-1] != new_unpad: # resize if shape[::-1] != new_unpad: # resize
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
value=(114, 114, 114)) # add border value=(114, 114, 114)) # add border

@ -131,6 +131,11 @@ class BasePredictor:
img /= 255 # 0 - 255 to 0.0 - 1.0 img /= 255 # 0 - 255 to 0.0 - 1.0
return img return img
def inference(self, im, *args, **kwargs):
visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
return self.model(im, augment=self.args.augment, visualize=visualize)
def pre_transform(self, im): def pre_transform(self, im):
"""Pre-transform input image before inference. """Pre-transform input image before inference.
@ -181,13 +186,13 @@ class BasePredictor:
"""Post-processes predictions for an image and returns them.""" """Post-processes predictions for an image and returns them."""
return preds return preds
def __call__(self, source=None, model=None, stream=False): def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
"""Performs inference on an image or stream.""" """Performs inference on an image or stream."""
self.stream = stream self.stream = stream
if stream: if stream:
return self.stream_inference(source, model) return self.stream_inference(source, model, *args, **kwargs)
else: else:
return list(self.stream_inference(source, model)) # merge list of Result into one return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
def predict_cli(self, source=None, model=None): def predict_cli(self, source=None, model=None):
"""Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode.""" """Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode."""
@ -209,7 +214,7 @@ class BasePredictor:
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
@smart_inference_mode() @smart_inference_mode()
def stream_inference(self, source=None, model=None): def stream_inference(self, source=None, model=None, *args, **kwargs):
"""Streams real-time inference on camera feed and saves results to file.""" """Streams real-time inference on camera feed and saves results to file."""
if self.args.verbose: if self.args.verbose:
LOGGER.info('') LOGGER.info('')
@ -236,8 +241,6 @@ class BasePredictor:
self.run_callbacks('on_predict_batch_start') self.run_callbacks('on_predict_batch_start')
self.batch = batch self.batch = batch
path, im0s, vid_cap, s = batch path, im0s, vid_cap, s = batch
visualize = increment_path(self.save_dir / Path(path[0]).stem,
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
# Preprocess # Preprocess
with profilers[0]: with profilers[0]:
@ -245,7 +248,7 @@ class BasePredictor:
# Inference # Inference
with profilers[1]: with profilers[1]:
preds = self.model(im, augment=self.args.augment, visualize=visualize) preds = self.inference(im, *args, **kwargs)
# Postprocess # Postprocess
with profilers[2]: with profilers[2]:

@ -170,7 +170,7 @@ class Results(SimpleClass):
font='Arial.ttf', font='Arial.ttf',
pil=False, pil=False,
img=None, img=None,
img_gpu=None, im_gpu=None,
kpt_line=True, kpt_line=True,
labels=True, labels=True,
boxes=True, boxes=True,
@ -188,7 +188,7 @@ class Results(SimpleClass):
font (str): The font to use for the text. font (str): The font to use for the text.
pil (bool): Whether to return the image as a PIL Image. pil (bool): Whether to return the image as a PIL Image.
img (numpy.ndarray): Plot to another image. if not, plot to original image. img (numpy.ndarray): Plot to another image. if not, plot to original image.
img_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting. im_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting.
kpt_line (bool): Whether to draw lines connecting keypoints. kpt_line (bool): Whether to draw lines connecting keypoints.
labels (bool): Whether to plot the label of bounding boxes. labels (bool): Whether to plot the label of bounding boxes.
boxes (bool): Whether to plot the bounding boxes. boxes (bool): Whether to plot the bounding boxes.
@ -226,12 +226,12 @@ class Results(SimpleClass):
# Plot Segment results # Plot Segment results
if pred_masks and show_masks: if pred_masks and show_masks:
if img_gpu is None: if im_gpu is None:
img = LetterBox(pred_masks.shape[1:])(image=annotator.result()) img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
img_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute( im_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute(
2, 0, 1).flip(0).contiguous() / 255 2, 0, 1).flip(0).contiguous() / 255
idx = pred_boxes.cls if pred_boxes else range(len(pred_masks)) idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=img_gpu) annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)
# Plot Detect results # Plot Detect results
if pred_boxes and show_boxes: if pred_boxes and show_boxes:

@ -8,12 +8,12 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
Adjust bounding boxes to stick to image border if they are within a certain threshold. Adjust bounding boxes to stick to image border if they are within a certain threshold.
Args: Args:
boxes: (n, 4) boxes (torch.Tensor): (n, 4)
image_shape: (height, width) image_shape (tuple): (height, width)
threshold: pixel threshold threshold (int): pixel threshold
Returns: Returns:
adjusted_boxes: adjusted bounding boxes adjusted_boxes (torch.Tensor): adjusted bounding boxes
""" """
# Image dimensions # Image dimensions
@ -32,11 +32,11 @@ def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=Fals
Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes. Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
Args: Args:
box1: (4, ) box1 (torch.Tensor): (4, )
boxes: (n, 4) boxes (torch.Tensor): (n, 4)
Returns: Returns:
high_iou_indices: Indices of boxes with IoU > thres high_iou_indices (torch.Tensor): Indices of boxes with IoU > thres
""" """
boxes = adjust_bboxes_to_image_border(boxes, image_shape) boxes = adjust_bboxes_to_image_border(boxes, image_shape)
# obtain coordinates for intersections # obtain coordinates for intersections

@ -21,7 +21,8 @@ GITHUB_ASSET_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in (''
[f'yolo_nas_{k}.pt' for k in 'sml'] + \ [f'yolo_nas_{k}.pt' for k in 'sml'] + \
[f'sam_{k}.pt' for k in 'bl'] + \ [f'sam_{k}.pt' for k in 'bl'] + \
[f'FastSAM-{k}.pt' for k in 'sx'] + \ [f'FastSAM-{k}.pt' for k in 'sx'] + \
[f'rtdetr-{k}.pt' for k in 'lx'] [f'rtdetr-{k}.pt' for k in 'lx'] + \
['mobile_sam.pt']
GITHUB_ASSET_STEMS = [Path(k).stem for k in GITHUB_ASSET_NAMES] GITHUB_ASSET_STEMS = [Path(k).stem for k in GITHUB_ASSET_NAMES]

@ -20,6 +20,7 @@ def _ntuple(n):
return parse return parse
to_2tuple = _ntuple(2)
to_4tuple = _ntuple(4) to_4tuple = _ntuple(4)
# `xyxy` means left top and right bottom # `xyxy` means left top and right bottom

@ -92,7 +92,7 @@ def segment2box(segment, width=640, height=640):
4, dtype=segment.dtype) # xyxy 4, dtype=segment.dtype) # xyxy
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None): def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
""" """
Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in
(img1_shape) to the shape of a different image (img0_shape). (img1_shape) to the shape of a different image (img0_shape).
@ -103,6 +103,8 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
img0_shape (tuple): the shape of the target image, in the format of (height, width). img0_shape (tuple): the shape of the target image, in the format of (height, width).
ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
calculated based on the size difference between the two images. calculated based on the size difference between the two images.
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
rescaling.
Returns: Returns:
boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2) boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
@ -115,6 +117,7 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
gain = ratio_pad[0][0] gain = ratio_pad[0][0]
pad = ratio_pad[1] pad = ratio_pad[1]
if padding:
boxes[..., [0, 2]] -= pad[0] # x padding boxes[..., [0, 2]] -= pad[0] # x padding
boxes[..., [1, 3]] -= pad[1] # y padding boxes[..., [1, 3]] -= pad[1] # y padding
boxes[..., :4] /= gain boxes[..., :4] /= gain
@ -552,7 +555,7 @@ def crop_mask(masks, boxes):
It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box
Args: Args:
masks (torch.Tensor): [h, w, n] tensor of masks masks (torch.Tensor): [n, h, w] tensor of masks
boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form
Returns: Returns:
@ -634,18 +637,36 @@ def process_mask_native(protos, masks_in, bboxes, shape):
""" """
c, mh, mw = protos.shape # CHW c, mh, mw = protos.shape # CHW
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
gain = min(mh / shape[0], mw / shape[1]) # gain = old / new masks = scale_masks(masks[None], shape)[0] # CHW
pad = (mw - shape[1] * gain) / 2, (mh - shape[0] * gain) / 2 # wh padding
top, left = int(pad[1]), int(pad[0]) # y, x
bottom, right = int(mh - pad[1]), int(mw - pad[0])
masks = masks[:, top:bottom, left:right]
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
masks = crop_mask(masks, bboxes) # CHW masks = crop_mask(masks, bboxes) # CHW
return masks.gt_(0.5) return masks.gt_(0.5)
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False): def scale_masks(masks, shape, padding=True):
"""
Rescale segment masks to shape.
Args:
masks (torch.Tensor): (N, C, H, W).
shape (tuple): Height and width.
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
rescaling.
"""
mh, mw = masks.shape[2:]
gain = min(mh / shape[0], mw / shape[1]) # gain = old / new
pad = [mw - shape[1] * gain, mh - shape[0] * gain] # wh padding
if padding:
pad[0] /= 2
pad[1] /= 2
top, left = (int(pad[1]), int(pad[0])) if padding else (0, 0) # y, x
bottom, right = (int(mh - pad[1]), int(mw - pad[0]))
masks = masks[..., top:bottom, left:right]
masks = F.interpolate(masks, shape, mode='bilinear', align_corners=False) # NCHW
return masks
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
""" """
Rescale segment coordinates (xyxy) from img1_shape to img0_shape Rescale segment coordinates (xyxy) from img1_shape to img0_shape
@ -655,6 +676,8 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False
img0_shape (tuple): the shape of the image that the segmentation is being applied to img0_shape (tuple): the shape of the image that the segmentation is being applied to
ratio_pad (tuple): the ratio of the image size to the padded image size. ratio_pad (tuple): the ratio of the image size to the padded image size.
normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
rescaling.
Returns: Returns:
coords (torch.Tensor): the segmented image. coords (torch.Tensor): the segmented image.
@ -666,6 +689,7 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False
gain = ratio_pad[0][0] gain = ratio_pad[0][0]
pad = ratio_pad[1] pad = ratio_pad[1]
if padding:
coords[..., 0] -= pad[0] # x padding coords[..., 0] -= pad[0] # x padding
coords[..., 1] -= pad[1] # y padding coords[..., 1] -= pad[1] # y padding
coords[..., 0] /= gain coords[..., 0] /= gain

Loading…
Cancel
Save