diff --git a/docs/models/index.md b/docs/models/index.md index 611cad7fba..e841db26a9 100644 --- a/docs/models/index.md +++ b/docs/models/index.md @@ -17,6 +17,7 @@ In this documentation, we provide information on four major models: 5. [YOLOv7](./yolov7.md): Updated YOLO models released in 2022 by the authors of YOLOv4. 6. [YOLOv8](./yolov8.md): The latest version of the YOLO family, featuring enhanced capabilities such as instance segmentation, pose/keypoints estimation, and classification. 7. [Segment Anything Model (SAM)](./sam.md): Meta's Segment Anything Model (SAM). +7. [Mobile Segment Anything Model (MobileSAM)](./mobile-sam.md): MobileSAM for mobile applications by Kyung Hee University. 8. [Fast Segment Anything Model (FastSAM)](./fast-sam.md): FastSAM by Image & Video Analysis Group, Institute of Automation, Chinese Academy of Sciences. 9. [YOLO-NAS](./yolo-nas.md): YOLO Neural Architecture Search (NAS) Models. 10. [Realtime Detection Transformers (RT-DETR)](./rtdetr.md): Baidu's PaddlePaddle Realtime Detection Transformer (RT-DETR) models. @@ -44,4 +45,4 @@ model.info() # display model information model.train(data="coco128.yaml", epochs=100) # train the model ``` -For more details on each model, their supported tasks, modes, and performance, please visit their respective documentation pages linked above. \ No newline at end of file +For more details on each model, their supported tasks, modes, and performance, please visit their respective documentation pages linked above. diff --git a/docs/models/mobile-sam.md b/docs/models/mobile-sam.md new file mode 100644 index 0000000000..94bc83db43 --- /dev/null +++ b/docs/models/mobile-sam.md @@ -0,0 +1,99 @@ +--- +comments: true +description: MobileSAM is a lightweight adaptation of the Segment Anything Model (SAM) designed for mobile applications. It maintains the full functionality of the original SAM while significantly improving speed, making it suitable for CPU-only edge devices, such as mobile phones. +keywords: MobileSAM, Faster Segment Anything, Segment Anything, Segment Anything Model, SAM, Meta SAM, image segmentation, promptable segmentation, zero-shot performance, SA-1B dataset, advanced architecture, auto-annotation, Ultralytics, pre-trained models, SAM base, SAM large, instance segmentation, computer vision, AI, artificial intelligence, machine learning, data annotation, segmentation masks, detection model, YOLO detection model, bibtex, Meta AI +--- + +![MobileSAM Logo](https://github.com/ChaoningZhang/MobileSAM/blob/master/assets/logo2.png?raw=true) + +# Faster Segment Anything (MobileSAM) + +The MobileSAM paper is now available on [ResearchGate](https://www.researchgate.net/publication/371851844_Faster_Segment_Anything_Towards_Lightweight_SAM_for_Mobile_Applications) and [arXiv](https://arxiv.org/pdf/2306.14289.pdf). The most recent version will initially appear on ResearchGate due to the delayed content update on arXiv. + +A demonstration of MobileSAM running on a CPU can be accessed at this [demo link](https://huggingface.co/spaces/dhkim2810/MobileSAM). The performance on a Mac i5 CPU takes approximately 3 seconds. On the Hugging Face demo, the interface and lower-performance CPUs contribute to a slower response, but it continues to function effectively. + +MobileSAM is implemented in various projects including [Grounding-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything), [AnyLabeling](https://github.com/vietanhdev/anylabeling), and [SegmentAnythingin3D](https://github.com/Jumpat/SegmentAnythingin3D). + +MobileSAM is trained on a single GPU with a 100k dataset (1% of the original images) in less than a day. The code for this training will be made available in the future. + +## Adapting from SAM to MobileSAM + +Since MobileSAM retains the same pipeline as the original SAM, we have incorporated the original's pre-processing, post-processing, and all other interfaces. Consequently, those currently using the original SAM can transition to MobileSAM with minimal effort. + +MobileSAM performs comparably to the original SAM and retains the same pipeline except for a change in the image encoder. Specifically, we replace the original heavyweight ViT-H encoder (632M) with a smaller Tiny-ViT (5M). On a single GPU, MobileSAM operates at about 12ms per image: 8ms on the image encoder and 4ms on the mask decoder. + +The following table provides a comparison of ViT-based image encoders: + +| Image Encoder | Original SAM | MobileSAM | +|---------------|--------------|-----------| +| Parameters | 611M | 5M | +| Speed | 452ms | 8ms | + +Both the original SAM and MobileSAM utilize the same prompt-guided mask decoder: + +| Mask Decoder | Original SAM | MobileSAM | +|--------------|--------------|-----------| +| Parameters | 3.876M | 3.876M | +| Speed | 4ms | 4ms | + +Here is the comparison of the whole pipeline: + +| Whole Pipeline (Enc+Dec) | Original SAM | MobileSAM | +|--------------------------|--------------|-----------| +| Parameters | 615M | 9.66M | +| Speed | 456ms | 12ms | + +The performance of MobileSAM and the original SAM are demonstrated using both a point and a box as prompts. + +![Image with Point as Prompt](https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/master/assets/mask_box.jpg?raw=true) + +![Image with Box as Prompt](https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/master/assets/mask_box.jpg?raw=true) + +With its superior performance, MobileSAM is approximately 5 times smaller and 7 times faster than the current FastSAM. More details are available at the [MobileSAM project page](https://github.com/ChaoningZhang/MobileSAM). + +## Testing MobileSAM in Ultralytics + +Just like the original SAM, we offer a straightforward testing method in Ultralytics, including modes for both Point and Box prompts. + +### Model Download + +You can download the model [here](https://github.com/ChaoningZhang/MobileSAM/blob/master/weights/mobile_sam.pt). + +### Point Prompt + +```python +from ultralytics import SAM + +# Load the model +model = SAM('mobile_sam.pt') + +# Predict a segment based on a point prompt +model.predict('ultralytics/assets/zidane.jpg', points=[900, 370], labels=[1]) +``` + +### Box Prompt + +```python +from ultralytics import SAM + +# Load the model +model = SAM('mobile_sam.pt') + +# Predict a segment based on a box prompt +model.predict('ultralytics/assets/zidane.jpg', bboxes=[439, 437, 524, 709]) +``` + +We have implemented `MobileSAM` and `SAM` using the same API. For more usage information, please see the [SAM page](./sam.md). + +### Citing MobileSAM + +If you find MobileSAM useful in your research or development work, please consider citing our paper: + +```bibtex +@article{mobile_sam, + title={Faster Segment Anything: Towards Lightweight SAM for Mobile Applications}, + author={Zhang, Chaoning and Han, Dongshen and Qiao, Yu and Kim, Jung Uk and Bae, Sung Ho and Lee, Seungkyu and Hong, Choong Seon}, + journal={arXiv preprint arXiv:2306.14289}, + year={2023} +} +``` diff --git a/docs/models/sam.md b/docs/models/sam.md index 79bbd01aae..e9f9ac035a 100644 --- a/docs/models/sam.md +++ b/docs/models/sam.md @@ -30,9 +30,33 @@ For an in-depth look at the Segment Anything Model and the SA-1B dataset, please The Segment Anything Model can be employed for a multitude of downstream tasks that go beyond its training data. This includes edge detection, object proposal generation, instance segmentation, and preliminary text-to-mask prediction. With prompt engineering, SAM can swiftly adapt to new tasks and data distributions in a zero-shot manner, establishing it as a versatile and potent tool for all your image segmentation needs. -!!! example "SAM prediction example" +### SAM prediction example - Device is determined automatically. If a GPU is available then it will be used, otherwise inference will run on CPU. +!!! example "Segment with prompts" + + Segment image with given prompts. + + === "Python" + + ```python + from ultralytics import SAM + + # Load a model + model = SAM('sam_b.pt') + + # Display model information (optional) + model.info() + + # Run inference with bboxes prompt + model('ultralytics/assets/zidane.jpg', bboxes=[439, 437, 524, 709]) + + # Run inference with points prompt + model.predict('ultralytics/assets/zidane.jpg', points=[900, 370], labels=[1]) + ``` + +!!! example "Segment everything" + + Segment the whole image. === "Python" @@ -45,7 +69,7 @@ The Segment Anything Model can be employed for a multitude of downstream tasks t # Display model information (optional) model.info() - # Run inference with the model + # Run inference model('path/to/image.jpg') ``` === "CLI" @@ -55,6 +79,48 @@ The Segment Anything Model can be employed for a multitude of downstream tasks t yolo predict model=sam_b.pt source=path/to/image.jpg ``` +- The logic here is to segment the whole image if you don't pass any prompts(bboxes/points/masks). + +!!! example "SAMPredictor example" + + This way you can set image once and run prompts inference multiple times without running image encoder multiple times. + + === "Prompt inference" + + ```python + from ultralytics.vit.sam import Predictor as SAMPredictor + + # Create SAMPredictor + overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024, model="mobile_sam.pt") + predictor = SAMPredictor(overrides=overrides) + + # Set image + predictor.set_image("ultralytics/assets/zidane.jpg") # set with image file + predictor.set_image(cv2.imread("ultralytics/assets/zidane.jpg")) # set with np.ndarray + results = predictor(bboxes=[439, 437, 524, 709]) + results = predictor(points=[900, 370], labels=[1]) + # Reset image + predictor.reset_image() + ``` + + Segment everything with additional args. + + === "Segment everything" + + ```python + from ultralytics.vit.sam import Predictor as SAMPredictor + + # Create SAMPredictor + overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024, model="mobile_sam.pt") + predictor = SAMPredictor(overrides=overrides) + + # segment with additional args + results = predictor(source="ultralytics/assets/zidane.jpg", crop_n_layers=1, points_stride=64) + + ``` + +- More additional args for `Segment everything` see [`Predictor/generate` Reference](../reference/vit/sam/predict.md). + ## Available Models and Supported Tasks | Model Type | Pre-trained Weights | Tasks Supported | @@ -76,21 +142,33 @@ Here we compare Meta's smallest SAM model, SAM-b, with Ultralytics smallest segm | Model | Size | Parameters | Speed (CPU) | |------------------------------------------------|----------------------------|------------------------|-------------------------| -| Meta's SAM-b | 358 MB | 94.7 M | 51096 ms | -| Ultralytics [YOLOv8n-seg](../tasks/segment.md) | **6.7 MB** (53.4x smaller) | **3.4 M** (27.9x less) | **59 ms** (866x faster) | +| Meta's SAM-b | 358 MB | 94.7 M | 51096 ms/im | +| [MobileSAM](mobile-sam.md) | 40.7 MB | 10.1 M | 46122 ms/im | +| [FastSAM-s](fast-sam.md) with YOLOv8 backbone | 23.7 MB | 11.8 M | 115 ms/im | +| Ultralytics [YOLOv8n-seg](../tasks/segment.md) | **6.7 MB** (53.4x smaller) | **3.4 M** (27.9x less) | **59 ms/im** (866x faster) | -This comparison shows the order-of-magnitude differences in the model sizes and speeds. Whereas SAM presents unique capabilities for automatic segmenting, it is not a direct competitor to YOLOv8 segment models, which are smaller, faster and more efficient since they are dedicated to more targeted use cases. +This comparison shows the order-of-magnitude differences in the model sizes and speeds between models. Whereas SAM presents unique capabilities for automatic segmenting, it is not a direct competitor to YOLOv8 segment models, which are smaller, faster and more efficient. -To reproduce this test: +Tests run on a 2023 Apple M2 Macbook with 16GB of RAM. To reproduce this test: ```python -from ultralytics import SAM, YOLO +from ultralytics import FastSAM, SAM, YOLO # Profile SAM-b model = SAM('sam_b.pt') model.info() model('ultralytics/assets') +# Profile MobileSAM +model = SAM('mobile_sam.pt') +model.info() +model('ultralytics/assets') + +# Profile FastSAM-s +model = FastSAM('FastSAM-s.pt') +model.info() +model('ultralytics/assets') + # Profile YOLOv8n-seg model = YOLO('yolov8n-seg.pt') model.info() @@ -140,4 +218,4 @@ If you find SAM useful in your research or development work, please consider cit We would like to express our gratitude to Meta AI for creating and maintaining this valuable resource for the computer vision community. -*keywords: Segment Anything, Segment Anything Model, SAM, Meta SAM, image segmentation, promptable segmentation, zero-shot performance, SA-1B dataset, advanced architecture, auto-annotation, Ultralytics, pre-trained models, SAM base, SAM large, instance segmentation, computer vision, AI, artificial intelligence, machine learning, data annotation, segmentation masks, detection model, YOLO detection model, bibtex, Meta AI.* \ No newline at end of file +*keywords: Segment Anything, Segment Anything Model, SAM, Meta SAM, image segmentation, promptable segmentation, zero-shot performance, SA-1B dataset, advanced architecture, auto-annotation, Ultralytics, pre-trained models, SAM base, SAM large, instance segmentation, computer vision, AI, artificial intelligence, machine learning, data annotation, segmentation masks, detection model, YOLO detection model, bibtex, Meta AI.* diff --git a/docs/reference/vit/sam/autosize.md b/docs/reference/vit/sam/autosize.md deleted file mode 100644 index ca84d37f6c..0000000000 --- a/docs/reference/vit/sam/autosize.md +++ /dev/null @@ -1,9 +0,0 @@ ---- -description: Learn how to use the ResizeLongestSide module in Ultralytics YOLO for automatic image resizing. Resize your images with ease. -keywords: ResizeLongestSide, Ultralytics YOLO, automatic image resizing, image resizing ---- - -## ResizeLongestSide ---- -### ::: ultralytics.vit.sam.autosize.ResizeLongestSide -

diff --git a/docs/reference/vit/sam/build.md b/docs/reference/vit/sam/build.md index 6c3962112e..c44e48b517 100644 --- a/docs/reference/vit/sam/build.md +++ b/docs/reference/vit/sam/build.md @@ -18,6 +18,11 @@ keywords: SAM, VIT, computer vision models, build SAM models, build VIT models, ### ::: ultralytics.vit.sam.build.build_sam_vit_b

+## build_mobile_sam +--- +### ::: ultralytics.vit.sam.build.build_mobile_sam +

+ ## _build_sam --- ### ::: ultralytics.vit.sam.build._build_sam diff --git a/docs/reference/vit/sam/modules/mask_generator.md b/docs/reference/vit/sam/modules/mask_generator.md deleted file mode 100644 index beec1d3f02..0000000000 --- a/docs/reference/vit/sam/modules/mask_generator.md +++ /dev/null @@ -1,9 +0,0 @@ ---- -description: Learn about the SamAutomaticMaskGenerator module in Ultralytics YOLO, an automatic mask generator for image segmentation. -keywords: SamAutomaticMaskGenerator, Ultralytics YOLO, automatic mask generator, image segmentation ---- - -## SamAutomaticMaskGenerator ---- -### ::: ultralytics.vit.sam.modules.mask_generator.SamAutomaticMaskGenerator -

diff --git a/docs/reference/vit/sam/modules/prompt_predictor.md b/docs/reference/vit/sam/modules/prompt_predictor.md deleted file mode 100644 index 00de169ef6..0000000000 --- a/docs/reference/vit/sam/modules/prompt_predictor.md +++ /dev/null @@ -1,9 +0,0 @@ ---- -description: Learn about PromptPredictor - a module in Ultralytics VIT SAM that predicts image captions based on prompts. Get started today!. -keywords: PromptPredictor, Ultralytics, YOLO, VIT SAM, image captioning, deep learning, computer vision ---- - -## PromptPredictor ---- -### ::: ultralytics.vit.sam.modules.prompt_predictor.PromptPredictor -

diff --git a/docs/reference/vit/sam/modules/tiny_encoder.md b/docs/reference/vit/sam/modules/tiny_encoder.md new file mode 100644 index 0000000000..eb20355ff1 --- /dev/null +++ b/docs/reference/vit/sam/modules/tiny_encoder.md @@ -0,0 +1,59 @@ +--- +description: Learn about the Conv2d_BN, MBConv, ConvLayer, Attention, BasicLayer, and TinyViT modules. +keywords: Conv2d_BN, MBConv, ConvLayer, Attention, BasicLayer, TinyViT +--- + +## Conv2d_BN +--- +### ::: ultralytics.vit.sam.modules.tiny_encoder.Conv2d_BN +

+ +## PatchEmbed +--- +### ::: ultralytics.vit.sam.modules.tiny_encoder.PatchEmbed +

+ +## MBConv +--- +### ::: ultralytics.vit.sam.modules.tiny_encoder.MBConv +

+ +## PatchMerging +--- +### ::: ultralytics.vit.sam.modules.tiny_encoder.PatchMerging +

+ +## ConvLayer +--- +### ::: ultralytics.vit.sam.modules.tiny_encoder.ConvLayer +

+ +## Mlp +--- +### ::: ultralytics.vit.sam.modules.tiny_encoder.Mlp +

+ +## Attention +--- +### ::: ultralytics.vit.sam.modules.tiny_encoder.Attention +

+ +## TinyViTBlock +--- +### ::: ultralytics.vit.sam.modules.tiny_encoder.TinyViTBlock +

+ +## BasicLayer +--- +### ::: ultralytics.vit.sam.modules.tiny_encoder.BasicLayer +

+ +## LayerNorm2d +--- +### ::: ultralytics.vit.sam.modules.tiny_encoder.LayerNorm2d +

+ +## TinyViT +--- +### ::: ultralytics.vit.sam.modules.tiny_encoder.TinyViT +

diff --git a/docs/reference/yolo/fastsam/model.md b/docs/reference/yolo/fastsam/model.md new file mode 100644 index 0000000000..b84e9946aa --- /dev/null +++ b/docs/reference/yolo/fastsam/model.md @@ -0,0 +1,9 @@ +--- +description: Learn how to use FastSAM in Ultralytics YOLO to improve object detection accuracy and speed. +keywords: FastSAM, object detection, accuracy, speed, Ultralytics YOLO +--- + +## FastSAM +--- +### ::: ultralytics.yolo.fastsam.model.FastSAM +

diff --git a/docs/reference/yolo/fastsam/predict.md b/docs/reference/yolo/fastsam/predict.md new file mode 100644 index 0000000000..377ae25a55 --- /dev/null +++ b/docs/reference/yolo/fastsam/predict.md @@ -0,0 +1,9 @@ +--- +description: FastSAMPredictor API reference and usage guide for the Ultralytics YOLO object detection library. +keywords: FastSAMPredictor, API, reference, usage, guide, Ultralytics, YOLO, object detection, library +--- + +## FastSAMPredictor +--- +### ::: ultralytics.yolo.fastsam.predict.FastSAMPredictor +

diff --git a/docs/reference/yolo/fastsam/prompt.md b/docs/reference/yolo/fastsam/prompt.md new file mode 100644 index 0000000000..e6fdc6dbf5 --- /dev/null +++ b/docs/reference/yolo/fastsam/prompt.md @@ -0,0 +1,9 @@ +--- +description: Learn how to use FastSAMPrompt in Ultralytics YOLO for fast and efficient object detection and tracking. +keywords: FastSAMPrompt, Ultralytics YOLO, object detection, tracking, fast, efficient +--- + +## FastSAMPrompt +--- +### ::: ultralytics.yolo.fastsam.prompt.FastSAMPrompt +

diff --git a/docs/reference/yolo/fastsam/utils.md b/docs/reference/yolo/fastsam/utils.md new file mode 100644 index 0000000000..6031bbc8e4 --- /dev/null +++ b/docs/reference/yolo/fastsam/utils.md @@ -0,0 +1,14 @@ +--- +description: Learn how to adjust bounding boxes to the image border in Ultralytics YOLO framework. Improve object detection accuracy by accounting for image borders. +keywords: adjust_bboxes_to_image_border, Ultralytics YOLO, object detection, bounding boxes, image border +--- + +## adjust_bboxes_to_image_border +--- +### ::: ultralytics.yolo.fastsam.utils.adjust_bboxes_to_image_border +

+ +## bbox_iou +--- +### ::: ultralytics.yolo.fastsam.utils.bbox_iou +

diff --git a/docs/reference/yolo/fastsam/val.md b/docs/reference/yolo/fastsam/val.md new file mode 100644 index 0000000000..e4bb5476c8 --- /dev/null +++ b/docs/reference/yolo/fastsam/val.md @@ -0,0 +1,9 @@ +--- +description: Learn about the FastSAMValidator module in Ultralytics YOLO. Validate and evaluate Segment Anything Model (SAM) datasets for object detection models with ease. +keywords: FastSAMValidator, Ultralytics YOLO, SAM datasets, object detection, validation, evaluation +--- + +## FastSAMValidator +--- +### ::: ultralytics.yolo.fastsam.val.FastSAMValidator +

diff --git a/docs/reference/yolo/utils/ops.md b/docs/reference/yolo/utils/ops.md index f35584a07e..ce4e4d5937 100644 --- a/docs/reference/yolo/utils/ops.md +++ b/docs/reference/yolo/utils/ops.md @@ -123,6 +123,11 @@ keywords: Ultralytics, YOLO, Utils Ops, Functions, coco80_to_coco91_class, scale ### ::: ultralytics.yolo.utils.ops.process_mask_native

+## scale_masks +--- +### ::: ultralytics.yolo.utils.ops.scale_masks +

+ ## scale_coords --- ### ::: ultralytics.yolo.utils.ops.scale_coords diff --git a/mkdocs.yml b/mkdocs.yml index 3295fb9220..75f8a74aba 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -168,6 +168,7 @@ nav: - YOLOv7: models/yolov7.md - YOLOv8: models/yolov8.md - SAM (Segment Anything Model): models/sam.md + - MobileSAM (Mobile Segment Anything Model): models/mobile-sam.md - FastSAM (Fast Segment Anything Model): models/fast-sam.md - YOLO-NAS (Neural Architecture Search): models/yolo-nas.md - RT-DETR (Realtime Detection Transformer): models/rtdetr.md @@ -282,15 +283,13 @@ nav: - val: reference/vit/rtdetr/val.md - sam: - amg: reference/vit/sam/amg.md - - autosize: reference/vit/sam/autosize.md - build: reference/vit/sam/build.md - model: reference/vit/sam/model.md - modules: - decoders: reference/vit/sam/modules/decoders.md - encoders: reference/vit/sam/modules/encoders.md - - mask_generator: reference/vit/sam/modules/mask_generator.md - - prompt_predictor: reference/vit/sam/modules/prompt_predictor.md - sam: reference/vit/sam/modules/sam.md + - tiny_encoder: reference/vit/sam/modules/tiny_encoder.md - transformer: reference/vit/sam/modules/transformer.md - predict: reference/vit/sam/predict.md - utils: @@ -319,6 +318,12 @@ nav: - results: reference/yolo/engine/results.md - trainer: reference/yolo/engine/trainer.md - validator: reference/yolo/engine/validator.md + - fastsam: + - model: reference/yolo/fastsam/model.md + - predict: reference/yolo/fastsam/predict.md + - prompt: reference/yolo/fastsam/prompt.md + - utils: reference/yolo/fastsam/utils.md + - val: reference/yolo/fastsam/val.md - nas: - model: reference/yolo/nas/model.md - predict: reference/yolo/nas/predict.md diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 9f64b83122..a4d865298c 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = '8.0.133' +__version__ = '8.0.134' from ultralytics.hub import start from ultralytics.vit.rtdetr import RTDETR diff --git a/ultralytics/vit/sam/__init__.py b/ultralytics/vit/sam/__init__.py index b47c04364e..35f4efa86a 100644 --- a/ultralytics/vit/sam/__init__.py +++ b/ultralytics/vit/sam/__init__.py @@ -1,5 +1,8 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -from .build import build_sam # noqa -from .model import SAM # noqa -from .modules.prompt_predictor import PromptPredictor # noqa +from .model import SAM +from .predict import Predictor + +# from .build import build_sam + +__all__ = 'SAM', 'Predictor' # tuple or list diff --git a/ultralytics/vit/sam/autosize.py b/ultralytics/vit/sam/autosize.py deleted file mode 100644 index ef33644540..0000000000 --- a/ultralytics/vit/sam/autosize.py +++ /dev/null @@ -1,94 +0,0 @@ -# Ultralytics YOLO 🚀, AGPL-3.0 license - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from copy import deepcopy -from typing import Tuple - -import numpy as np -import torch -from torch.nn import functional as F -from torchvision.transforms.functional import resize, to_pil_image # type: ignore - - -class ResizeLongestSide: - """ - Resizes images to the longest side 'target_length', as well as provides - methods for resizing coordinates and boxes. Provides methods for - transforming both numpy array and batched torch tensors. - """ - - def __init__(self, target_length: int) -> None: - self.target_length = target_length - - def apply_image(self, image: np.ndarray) -> np.ndarray: - """ - Expects a numpy array with shape HxWxC in uint8 format. - """ - target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) - return np.array(resize(to_pil_image(image), target_size)) - - def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: - """ - Expects a numpy array of length 2 in the final dimension. Requires the - original image size in (H, W) format. - """ - old_h, old_w = original_size - new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length) - coords = deepcopy(coords).astype(float) - coords[..., 0] = coords[..., 0] * (new_w / old_w) - coords[..., 1] = coords[..., 1] * (new_h / old_h) - return coords - - def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: - """ - Expects a numpy array shape Bx4. Requires the original image size - in (H, W) format. - """ - boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) - return boxes.reshape(-1, 4) - - def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: - """ - Expects batched images with shape BxCxHxW and float format. This - transformation may not exactly match apply_image. apply_image is - the transformation expected by the model. - """ - # Expects an image in BCHW format. May not exactly match apply_image. - target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) - return F.interpolate(image, target_size, mode='bilinear', align_corners=False, antialias=True) - - def apply_coords_torch(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor: - """ - Expects a torch tensor with length 2 in the last dimension. Requires the - original image size in (H, W) format. - """ - old_h, old_w = original_size - new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length) - coords = deepcopy(coords).to(torch.float) - coords[..., 0] = coords[..., 0] * (new_w / old_w) - coords[..., 1] = coords[..., 1] * (new_h / old_h) - return coords - - def apply_boxes_torch(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor: - """ - Expects a torch tensor with shape Bx4. Requires the original image - size in (H, W) format. - """ - boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) - return boxes.reshape(-1, 4) - - @staticmethod - def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: - """ - Compute the output size given input size and target long side length. - """ - scale = long_side_length * 1.0 / max(oldh, oldw) - newh, neww = oldh * scale, oldw * scale - neww = int(neww + 0.5) - newh = int(newh + 0.5) - return (newh, neww) diff --git a/ultralytics/vit/sam/build.py b/ultralytics/vit/sam/build.py index 73b1a03a3a..3572c2e939 100644 --- a/ultralytics/vit/sam/build.py +++ b/ultralytics/vit/sam/build.py @@ -14,6 +14,7 @@ from ...yolo.utils.downloads import attempt_download_asset from .modules.decoders import MaskDecoder from .modules.encoders import ImageEncoderViT, PromptEncoder from .modules.sam import Sam +from .modules.tiny_encoder import TinyViT from .modules.transformer import TwoWayTransformer @@ -50,33 +51,60 @@ def build_sam_vit_b(checkpoint=None): ) -def _build_sam( - encoder_embed_dim, - encoder_depth, - encoder_num_heads, - encoder_global_attn_indexes, - checkpoint=None, -): +def build_mobile_sam(checkpoint=None): + """Build and return Mobile Segment Anything Model (Mobile-SAM).""" + return _build_sam( + encoder_embed_dim=[64, 128, 160, 320], + encoder_depth=[2, 2, 6, 2], + encoder_num_heads=[2, 4, 5, 10], + encoder_global_attn_indexes=None, + mobile_sam=True, + checkpoint=checkpoint, + ) + + +def _build_sam(encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, + mobile_sam=False): """Builds the selected SAM model architecture.""" prompt_embed_dim = 256 image_size = 1024 vit_patch_size = 16 image_embedding_size = image_size // vit_patch_size + image_encoder = (TinyViT( + img_size=1024, + in_chans=3, + num_classes=1000, + embed_dims=encoder_embed_dim, + depths=encoder_depth, + num_heads=encoder_num_heads, + window_sizes=[7, 7, 14, 7], + mlp_ratio=4.0, + drop_rate=0.0, + drop_path_rate=0.0, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=0.8, + ) if mobile_sam else ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + )) sam = Sam( - image_encoder=ImageEncoderViT( - depth=encoder_depth, - embed_dim=encoder_embed_dim, - img_size=image_size, - mlp_ratio=4, - norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), - num_heads=encoder_num_heads, - patch_size=vit_patch_size, - qkv_bias=True, - use_rel_pos=True, - global_attn_indexes=encoder_global_attn_indexes, - window_size=14, - out_chans=prompt_embed_dim, - ), + image_encoder=image_encoder, prompt_encoder=PromptEncoder( embed_dim=prompt_embed_dim, image_embedding_size=(image_embedding_size, image_embedding_size), @@ -98,20 +126,22 @@ def _build_sam( pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375], ) - sam.eval() if checkpoint is not None: checkpoint = attempt_download_asset(checkpoint) with open(checkpoint, 'rb') as f: state_dict = torch.load(f) sam.load_state_dict(state_dict) + sam.eval() + # sam.load_state_dict(torch.load(checkpoint), strict=True) + # sam.eval() return sam sam_model_map = { - # "default": build_sam_vit_h, 'sam_h.pt': build_sam_vit_h, 'sam_l.pt': build_sam_vit_l, - 'sam_b.pt': build_sam_vit_b, } + 'sam_b.pt': build_sam_vit_b, + 'mobile_sam.pt': build_mobile_sam, } def build_sam(ckpt='sam_b.pt'): diff --git a/ultralytics/vit/sam/model.py b/ultralytics/vit/sam/model.py index 83861f4b9c..925328ef78 100644 --- a/ultralytics/vit/sam/model.py +++ b/ultralytics/vit/sam/model.py @@ -4,8 +4,8 @@ SAM model interface """ from ultralytics.yolo.cfg import get_cfg +from ultralytics.yolo.utils.torch_utils import model_info -from ...yolo.utils.torch_utils import model_info from .build import build_sam from .predict import Predictor @@ -20,16 +20,16 @@ class SAM: self.task = 'segment' # required self.predictor = None # reuse predictor - def predict(self, source, stream=False, **kwargs): + def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs): """Predicts and returns segmentation masks for given image or video source.""" - overrides = dict(conf=0.25, task='segment', mode='predict') + overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024) overrides.update(kwargs) # prefer kwargs if not self.predictor: self.predictor = Predictor(overrides=overrides) self.predictor.setup_model(model=self.model) else: # only update args if predictor is already setup self.predictor.args = get_cfg(self.predictor.args, overrides) - return self.predictor(source, stream=stream) + return self.predictor(source, stream=stream, bboxes=bboxes, points=points, labels=labels) def train(self, **kwargs): """Function trains models but raises an error as SAM models do not support training.""" @@ -39,9 +39,9 @@ class SAM: """Run validation given dataset.""" raise NotImplementedError("SAM models don't support validation") - def __call__(self, source=None, stream=False, **kwargs): + def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs): """Calls the 'predict' function with given arguments to perform object detection.""" - return self.predict(source, stream, **kwargs) + return self.predict(source, stream, bboxes, points, labels, **kwargs) def __getattr__(self, attr): """Raises error if object has no requested attribute.""" diff --git a/ultralytics/vit/sam/modules/mask_generator.py b/ultralytics/vit/sam/modules/mask_generator.py deleted file mode 100644 index 8c1e00ea17..0000000000 --- a/ultralytics/vit/sam/modules/mask_generator.py +++ /dev/null @@ -1,353 +0,0 @@ -# Ultralytics YOLO 🚀, AGPL-3.0 license - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Any, Dict, List, Optional, Tuple - -import numpy as np -import torch -from torchvision.ops.boxes import batched_nms, box_area # type: ignore - -from ..amg import (MaskData, area_from_rle, batch_iterator, batched_mask_to_box, box_xyxy_to_xywh, - build_all_layer_point_grids, calculate_stability_score, coco_encode_rle, generate_crop_boxes, - is_box_near_crop_edge, mask_to_rle_pytorch, remove_small_regions, rle_to_mask, uncrop_boxes_xyxy, - uncrop_masks, uncrop_points) -from .prompt_predictor import PromptPredictor -from .sam import Sam - - -class SamAutomaticMaskGenerator: - - def __init__( - self, - model: Sam, - points_per_side: Optional[int] = 32, - points_per_batch: int = 64, - pred_iou_thresh: float = 0.88, - stability_score_thresh: float = 0.95, - stability_score_offset: float = 1.0, - box_nms_thresh: float = 0.7, - crop_n_layers: int = 0, - crop_nms_thresh: float = 0.7, - crop_overlap_ratio: float = 512 / 1500, - crop_n_points_downscale_factor: int = 1, - point_grids: Optional[List[np.ndarray]] = None, - min_mask_region_area: int = 0, - output_mode: str = 'binary_mask', - ) -> None: - """ - Using a SAM model, generates masks for the entire image. - Generates a grid of point prompts over the image, then filters - low quality and duplicate masks. The default settings are chosen - for SAM with a ViT-H backbone. - - Arguments: - model (Sam): The SAM model to use for mask prediction. - points_per_side (int, None): The number of points to be sampled - along one side of the image. The total number of points is - points_per_side**2. If None, 'point_grids' must provide explicit - point sampling. - points_per_batch (int): Sets the number of points run simultaneously - by the model. Higher numbers may be faster but use more GPU memory. - pred_iou_thresh (float): A filtering threshold in [0,1], using the - model's predicted mask quality. - stability_score_thresh (float): A filtering threshold in [0,1], using - the stability of the mask under changes to the cutoff used to binarize - the model's mask predictions. - stability_score_offset (float): The amount to shift the cutoff when - calculated the stability score. - box_nms_thresh (float): The box IoU cutoff used by non-maximal - suppression to filter duplicate masks. - crop_n_layers (int): If >0, mask prediction will be run again on - crops of the image. Sets the number of layers to run, where each - layer has 2**i_layer number of image crops. - crop_nms_thresh (float): The box IoU cutoff used by non-maximal - suppression to filter duplicate masks between different crops. - crop_overlap_ratio (float): Sets the degree to which crops overlap. - In the first crop layer, crops will overlap by this fraction of - the image length. Later layers with more crops scale down this overlap. - crop_n_points_downscale_factor (int): The number of points-per-side - sampled in layer n is scaled down by crop_n_points_downscale_factor**n. - point_grids (list(np.ndarray), None): A list over explicit grids - of points used for sampling, normalized to [0,1]. The nth grid in the - list is used in the nth crop layer. Exclusive with points_per_side. - min_mask_region_area (int): If >0, postprocessing will be applied - to remove disconnected regions and holes in masks with area smaller - than min_mask_region_area. Requires opencv. - output_mode (str): The form masks are returned in. Can be 'binary_mask', - 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. - For large resolutions, 'binary_mask' may consume large amounts of - memory. - """ - - assert (points_per_side is None) != (point_grids is None), \ - 'Exactly one of points_per_side or point_grid must be provided.' - if points_per_side is not None: - self.point_grids = build_all_layer_point_grids( - points_per_side, - crop_n_layers, - crop_n_points_downscale_factor, - ) - elif point_grids is not None: - self.point_grids = point_grids - else: - raise ValueError("Can't have both points_per_side and point_grid be None.") - - assert output_mode in {'binary_mask', 'uncompressed_rle', 'coco_rle'}, f'Unknown output_mode {output_mode}.' - if output_mode == 'coco_rle': - from pycocotools import mask as mask_utils # type: ignore # noqa: F401 - - if min_mask_region_area > 0: - import cv2 # type: ignore # noqa: F401 - - self.predictor = PromptPredictor(model) - self.points_per_batch = points_per_batch - self.pred_iou_thresh = pred_iou_thresh - self.stability_score_thresh = stability_score_thresh - self.stability_score_offset = stability_score_offset - self.box_nms_thresh = box_nms_thresh - self.crop_n_layers = crop_n_layers - self.crop_nms_thresh = crop_nms_thresh - self.crop_overlap_ratio = crop_overlap_ratio - self.crop_n_points_downscale_factor = crop_n_points_downscale_factor - self.min_mask_region_area = min_mask_region_area - self.output_mode = output_mode - - # TODO: Temporary implementation for compatibility - def __call__(self, image: np.ndarray, augment=False, visualize=False) -> List[Dict[str, Any]]: - return self.generate(image) - - @torch.no_grad() - def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: - """ - Generates masks for the given image. - - Arguments: - image (np.ndarray): The image to generate masks for, in HWC uint8 format. - - Returns: - list(dict(str, any)): A list over records for masks. Each record is a dict containing the following keys: - segmentation (dict(str, any), np.ndarray): The mask. If - output_mode='binary_mask', is an array of shape HW. Otherwise, - is a dictionary containing the RLE. - bbox (list(float)): The box around the mask, in XYWH format. - area (int): The area in pixels of the mask. - predicted_iou (float): The model's own prediction of the mask's - quality. This is filtered by the pred_iou_thresh parameter. - point_coords (list(list(float))): The point coordinates input - to the model to generate this mask. - stability_score (float): A measure of the mask's quality. This - is filtered on using the stability_score_thresh parameter. - crop_box (list(float)): The crop of the image used to generate - the mask, given in XYWH format. - """ - - # Generate masks - mask_data = self._generate_masks(image) - - # Filter small disconnected regions and holes in masks - if self.min_mask_region_area > 0: - mask_data = self.postprocess_small_regions( - mask_data, - self.min_mask_region_area, - max(self.box_nms_thresh, self.crop_nms_thresh), - ) - - # Encode masks - if self.output_mode == 'coco_rle': - mask_data['segmentations'] = [coco_encode_rle(rle) for rle in mask_data['rles']] - elif self.output_mode == 'binary_mask': - mask_data['segmentations'] = [rle_to_mask(rle) for rle in mask_data['rles']] - else: - mask_data['segmentations'] = mask_data['rles'] - - # Write mask records - curr_anns = [] - for idx in range(len(mask_data['segmentations'])): - ann = { - 'segmentation': mask_data['segmentations'][idx], - 'area': area_from_rle(mask_data['rles'][idx]), - 'bbox': box_xyxy_to_xywh(mask_data['boxes'][idx]).tolist(), - 'predicted_iou': mask_data['iou_preds'][idx].item(), - 'point_coords': [mask_data['points'][idx].tolist()], - 'stability_score': mask_data['stability_score'][idx].item(), - 'crop_box': box_xyxy_to_xywh(mask_data['crop_boxes'][idx]).tolist(), } - curr_anns.append(ann) - - return curr_anns - - def _generate_masks(self, image: np.ndarray) -> MaskData: - orig_size = image.shape[:2] - crop_boxes, layer_idxs = generate_crop_boxes(orig_size, self.crop_n_layers, self.crop_overlap_ratio) - - # Iterate over image crops - data = MaskData() - for crop_box, layer_idx in zip(crop_boxes, layer_idxs): - crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) - data.cat(crop_data) - - # Remove duplicate masks between crops - if len(crop_boxes) > 1: - # Prefer masks from smaller crops - scores = 1 / box_area(data['crop_boxes']) - scores = scores.to(data['boxes'].device) - keep_by_nms = batched_nms( - data['boxes'].float(), - scores, - torch.zeros_like(data['boxes'][:, 0]), # categories - iou_threshold=self.crop_nms_thresh, - ) - data.filter(keep_by_nms) - - data.to_numpy() - return data - - def _process_crop( - self, - image: np.ndarray, - crop_box: List[int], - crop_layer_idx: int, - orig_size: Tuple[int, ...], - ) -> MaskData: - # Crop the image and calculate embeddings - x0, y0, x1, y1 = crop_box - cropped_im = image[y0:y1, x0:x1, :] - cropped_im_size = cropped_im.shape[:2] - self.predictor.set_image(cropped_im) - - # Get points for this crop - points_scale = np.array(cropped_im_size)[None, ::-1] - points_for_image = self.point_grids[crop_layer_idx] * points_scale - - # Generate masks for this crop in batches - data = MaskData() - for (points, ) in batch_iterator(self.points_per_batch, points_for_image): - batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) - data.cat(batch_data) - del batch_data - self.predictor.reset_image() - - # Remove duplicates within this crop. - keep_by_nms = batched_nms( - data['boxes'].float(), - data['iou_preds'], - torch.zeros_like(data['boxes'][:, 0]), # categories - iou_threshold=self.box_nms_thresh, - ) - data.filter(keep_by_nms) - - # Return to the original image frame - data['boxes'] = uncrop_boxes_xyxy(data['boxes'], crop_box) - data['points'] = uncrop_points(data['points'], crop_box) - data['crop_boxes'] = torch.tensor([crop_box for _ in range(len(data['rles']))]) - - return data - - def _process_batch( - self, - points: np.ndarray, - im_size: Tuple[int, ...], - crop_box: List[int], - orig_size: Tuple[int, ...], - ) -> MaskData: - orig_h, orig_w = orig_size - - # Run model on this batch - transformed_points = self.predictor.transform.apply_coords(points, im_size) - in_points = torch.as_tensor(transformed_points, device=self.predictor.device) - in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) - masks, iou_preds, _ = self.predictor.predict_torch( - in_points[:, None, :], - in_labels[:, None], - multimask_output=True, - return_logits=True, - ) - - # Serialize predictions and store in MaskData - data = MaskData( - masks=masks.flatten(0, 1), - iou_preds=iou_preds.flatten(0, 1), - points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), - ) - del masks - - # Filter by predicted IoU - if self.pred_iou_thresh > 0.0: - keep_mask = data['iou_preds'] > self.pred_iou_thresh - data.filter(keep_mask) - - # Calculate stability score - data['stability_score'] = calculate_stability_score(data['masks'], self.predictor.model.mask_threshold, - self.stability_score_offset) - if self.stability_score_thresh > 0.0: - keep_mask = data['stability_score'] >= self.stability_score_thresh - data.filter(keep_mask) - - # Threshold masks and calculate boxes - data['masks'] = data['masks'] > self.predictor.model.mask_threshold - data['boxes'] = batched_mask_to_box(data['masks']) - - # Filter boxes that touch crop boundaries - keep_mask = ~is_box_near_crop_edge(data['boxes'], crop_box, [0, 0, orig_w, orig_h]) - if not torch.all(keep_mask): - data.filter(keep_mask) - - # Compress to RLE - data['masks'] = uncrop_masks(data['masks'], crop_box, orig_h, orig_w) - data['rles'] = mask_to_rle_pytorch(data['masks']) - del data['masks'] - - return data - - @staticmethod - def postprocess_small_regions(mask_data: MaskData, min_area: int, nms_thresh: float) -> MaskData: - """ - Removes small disconnected regions and holes in masks, then reruns - box NMS to remove any new duplicates. - - Edits mask_data in place. - - Requires open-cv as a dependency. - """ - if len(mask_data['rles']) == 0: - return mask_data - - # Filter small disconnected regions and holes - new_masks = [] - scores = [] - for rle in mask_data['rles']: - mask = rle_to_mask(rle) - - mask, changed = remove_small_regions(mask, min_area, mode='holes') - unchanged = not changed - mask, changed = remove_small_regions(mask, min_area, mode='islands') - unchanged = unchanged and not changed - - new_masks.append(torch.as_tensor(mask).unsqueeze(0)) - # Give score=0 to changed masks and score=1 to unchanged masks - # so NMS will prefer ones that didn't need postprocessing - scores.append(float(unchanged)) - - # Recalculate boxes and remove any new duplicates - masks = torch.cat(new_masks, dim=0) - boxes = batched_mask_to_box(masks) - keep_by_nms = batched_nms( - boxes.float(), - torch.as_tensor(scores), - torch.zeros_like(boxes[:, 0]), # categories - iou_threshold=nms_thresh, - ) - - # Only recalculate RLEs for masks that have changed - for i_mask in keep_by_nms: - if scores[i_mask] == 0.0: - mask_torch = masks[i_mask].unsqueeze(0) - mask_data['rles'][i_mask] = mask_to_rle_pytorch(mask_torch)[0] - mask_data['boxes'][i_mask] = boxes[i_mask] # update res directly - mask_data.filter(keep_by_nms) - - return mask_data diff --git a/ultralytics/vit/sam/modules/prompt_predictor.py b/ultralytics/vit/sam/modules/prompt_predictor.py deleted file mode 100644 index bf89893458..0000000000 --- a/ultralytics/vit/sam/modules/prompt_predictor.py +++ /dev/null @@ -1,242 +0,0 @@ -# Ultralytics YOLO 🚀, AGPL-3.0 license - -from typing import Optional, Tuple - -import numpy as np -import torch - -from ..autosize import ResizeLongestSide -from .sam import Sam - - -class PromptPredictor: - - def __init__(self, sam_model: Sam) -> None: - """ - Uses SAM to calculate the image embedding for an image, and then - allow repeated, efficient mask prediction given prompts. - - Arguments: - sam_model (Sam): The model to use for mask prediction. - """ - super().__init__() - self.model = sam_model - self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) - self.reset_image() - - def set_image(self, image: np.ndarray, image_format: str = 'RGB') -> None: - """ - Calculates the image embeddings for the provided image, allowing - masks to be predicted with the 'predict' method. - - Arguments: - image (np.ndarray): The image for calculating masks. Expects an - image in HWC uint8 format, with pixel values in [0, 255]. - image_format (str): The color format of the image, in ['RGB', 'BGR']. - """ - assert image_format in {'RGB', 'BGR'}, f"image_format must be in ['RGB', 'BGR'], is {image_format}." - if image_format != self.model.image_format: - image = image[..., ::-1] - - # Transform the image to the form expected by the model - input_image = self.transform.apply_image(image) - input_image_torch = torch.as_tensor(input_image, device=self.device) - input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] - - self.set_torch_image(input_image_torch, image.shape[:2]) - - @torch.no_grad() - def set_torch_image(self, transformed_image: torch.Tensor, original_image_size: Tuple[int, ...]) -> None: - """ - Calculates the image embeddings for the provided image, allowing - masks to be predicted with the 'predict' method. Expects the input - image to be already transformed to the format expected by the model. - - Arguments: - transformed_image (torch.Tensor): The input image, with shape - 1x3xHxW, which has been transformed with ResizeLongestSide. - original_image_size (tuple(int, int)): The size of the image - before transformation, in (H, W) format. - """ - if len(transformed_image.shape) != 4 \ - or transformed_image.shape[1] != 3 \ - or max(*transformed_image.shape[2:]) != self.model.image_encoder.img_size: - raise ValueError('set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}.') - self.reset_image() - - self.original_size = original_image_size - self.input_size = tuple(transformed_image.shape[-2:]) - input_image = self.model.preprocess(transformed_image) - self.features = self.model.image_encoder(input_image) - self.is_image_set = True - - def predict( - self, - point_coords: Optional[np.ndarray] = None, - point_labels: Optional[np.ndarray] = None, - box: Optional[np.ndarray] = None, - mask_input: Optional[np.ndarray] = None, - multimask_output: bool = True, - return_logits: bool = False, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Predict masks for the given input prompts, using the currently set image. - - Arguments: - point_coords (np.ndarray, None): A Nx2 array of point prompts to the - model. Each point is in (X,Y) in pixels. - point_labels (np.ndarray, None): A length N array of labels for the - point prompts. 1 indicates a foreground point and 0 indicates a - background point. - box (np.ndarray, None): A length 4 array given a box prompt to the - model, in XYXY format. - mask_input (np.ndarray): A low resolution mask input to the model, typically - coming from a previous prediction iteration. Has form 1xHxW, where - for SAM, H=W=256. - multimask_output (bool): If true, the model will return three masks. - For ambiguous input prompts (such as a single click), this will often - produce better masks than a single prediction. If only a single - mask is needed, the model's predicted quality score can be used - to select the best mask. For non-ambiguous prompts, such as multiple - input prompts, multimask_output=False can give better results. - return_logits (bool): If true, returns un-thresholded masks logits - instead of a binary mask. - - Returns: - (np.ndarray): The output masks in CxHxW format, where C is the - number of masks, and (H, W) is the original image size. - (np.ndarray): An array of length C containing the model's - predictions for the quality of each mask. - (np.ndarray): An array of shape CxHxW, where C is the number - of masks and H=W=256. These low resolution logits can be passed to - a subsequent iteration as mask input. - """ - if not self.is_image_set: - raise RuntimeError('An image must be set with .set_image(...) before mask prediction.') - - # Transform input prompts - coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None - if point_coords is not None: - assert (point_labels is not None), 'point_labels must be supplied if point_coords is supplied.' - point_coords = self.transform.apply_coords(point_coords, self.original_size) - coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) - labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) - coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] - if box is not None: - box = self.transform.apply_boxes(box, self.original_size) - box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) - box_torch = box_torch[None, :] - if mask_input is not None: - mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) - mask_input_torch = mask_input_torch[None, :, :, :] - - masks, iou_predictions, low_res_masks = self.predict_torch( - coords_torch, - labels_torch, - box_torch, - mask_input_torch, - multimask_output, - return_logits=return_logits, - ) - - masks_np = masks[0].detach().cpu().numpy() - iou_predictions_np = iou_predictions[0].detach().cpu().numpy() - low_res_masks_np = low_res_masks[0].detach().cpu().numpy() - return masks_np, iou_predictions_np, low_res_masks_np - - @torch.no_grad() - def predict_torch( - self, - point_coords: Optional[torch.Tensor], - point_labels: Optional[torch.Tensor], - boxes: Optional[torch.Tensor] = None, - mask_input: Optional[torch.Tensor] = None, - multimask_output: bool = True, - return_logits: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Predict masks for the given input prompts, using the currently set image. - Input prompts are batched torch tensors and are expected to already be - transformed to the input frame using ResizeLongestSide. - - Arguments: - point_coords (torch.Tensor, None): A BxNx2 array of point prompts to the - model. Each point is in (X,Y) in pixels. - point_labels (torch.Tensor, None): A BxN array of labels for the - point prompts. 1 indicates a foreground point and 0 indicates a - background point. - boxes (np.ndarray, None): A Bx4 array given a box prompt to the - model, in XYXY format. - mask_input (np.ndarray): A low resolution mask input to the model, typically - coming from a previous prediction iteration. Has form Bx1xHxW, where - for SAM, H=W=256. Masks returned by a previous iteration of the - predict method do not need further transformation. - multimask_output (bool): If true, the model will return three masks. - For ambiguous input prompts (such as a single click), this will often - produce better masks than a single prediction. If only a single - mask is needed, the model's predicted quality score can be used - to select the best mask. For non-ambiguous prompts, such as multiple - input prompts, multimask_output=False can give better results. - return_logits (bool): If true, returns un-thresholded masks logits - instead of a binary mask. - - Returns: - (torch.Tensor): The output masks in BxCxHxW format, where C is the - number of masks, and (H, W) is the original image size. - (torch.Tensor): An array of shape BxC containing the model's - predictions for the quality of each mask. - (torch.Tensor): An array of shape BxCxHxW, where C is the number - of masks and H=W=256. These low res logits can be passed to - a subsequent iteration as mask input. - """ - if not self.is_image_set: - raise RuntimeError('An image must be set with .set_image(...) before mask prediction.') - - points = (point_coords, point_labels) if point_coords is not None else None - # Embed prompts - sparse_embeddings, dense_embeddings = self.model.prompt_encoder( - points=points, - boxes=boxes, - masks=mask_input, - ) - - # Predict masks - low_res_masks, iou_predictions = self.model.mask_decoder( - image_embeddings=self.features, - image_pe=self.model.prompt_encoder.get_dense_pe(), - sparse_prompt_embeddings=sparse_embeddings, - dense_prompt_embeddings=dense_embeddings, - multimask_output=multimask_output, - ) - - # Upscale the masks to the original image resolution - masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) - - if not return_logits: - masks = masks > self.model.mask_threshold - - return masks, iou_predictions, low_res_masks - - def get_image_embedding(self) -> torch.Tensor: - """ - Returns the image embeddings for the currently set image, with - shape 1xCxHxW, where C is the embedding dimension and (H,W) are - the embedding spatial dimension of SAM (typically C=256, H=W=64). - """ - if not self.is_image_set: - raise RuntimeError('An image must be set with .set_image(...) to generate an embedding.') - assert self.features is not None, 'Features must exist if an image has been set.' - return self.features - - @property - def device(self) -> torch.device: - return self.model.device - - def reset_image(self) -> None: - """Resets the currently set image.""" - self.is_image_set = False - self.features = None - self.orig_h = None - self.orig_w = None - self.input_h = None - self.input_w = None diff --git a/ultralytics/vit/sam/modules/tiny_encoder.py b/ultralytics/vit/sam/modules/tiny_encoder.py new file mode 100644 index 0000000000..e3f51017f2 --- /dev/null +++ b/ultralytics/vit/sam/modules/tiny_encoder.py @@ -0,0 +1,653 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +# -------------------------------------------------------- +# TinyViT Model Architecture +# Copyright (c) 2022 Microsoft +# Adapted from LeViT and Swin Transformer +# LeViT: (https://github.com/facebookresearch/levit) +# Swin: (https://github.com/microsoft/swin-transformer) +# Build the TinyViT Model +# -------------------------------------------------------- + +import itertools +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from ultralytics.yolo.utils.instance import to_2tuple + + +class Conv2d_BN(torch.nn.Sequential): + + def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): + super().__init__() + self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + m = torch.nn.Conv2d(w.size(1) * self.c.groups, + w.size(0), + w.shape[2:], + stride=self.c.stride, + padding=self.c.padding, + dilation=self.c.dilation, + groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +# NOTE: This module and timm package is needed only for training. +# from ultralytics.yolo.utils.checks import check_requirements +# check_requirements('timm') +# from timm.models.layers import DropPath as TimmDropPath +# from timm.models.layers import trunc_normal_ +# class DropPath(TimmDropPath): +# +# def __init__(self, drop_prob=None): +# super().__init__(drop_prob=drop_prob) +# self.drop_prob = drop_prob +# +# def __repr__(self): +# msg = super().__repr__() +# msg += f'(drop_prob={self.drop_prob})' +# return msg + + +class PatchEmbed(nn.Module): + + def __init__(self, in_chans, embed_dim, resolution, activation): + super().__init__() + img_size: Tuple[int, int] = to_2tuple(resolution) + self.patches_resolution = (img_size[0] // 4, img_size[1] // 4) + self.num_patches = self.patches_resolution[0] * \ + self.patches_resolution[1] + self.in_chans = in_chans + self.embed_dim = embed_dim + n = embed_dim + self.seq = nn.Sequential( + Conv2d_BN(in_chans, n // 2, 3, 2, 1), + activation(), + Conv2d_BN(n // 2, n, 3, 2, 1), + ) + + def forward(self, x): + return self.seq(x) + + +class MBConv(nn.Module): + + def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path): + super().__init__() + self.in_chans = in_chans + self.hidden_chans = int(in_chans * expand_ratio) + self.out_chans = out_chans + + self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1) + self.act1 = activation() + + self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, ks=3, stride=1, pad=1, groups=self.hidden_chans) + self.act2 = activation() + + self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0) + self.act3 = activation() + + # NOTE: `DropPath` is needed only for training. + # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path = nn.Identity() + + def forward(self, x): + shortcut = x + + x = self.conv1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.act2(x) + + x = self.conv3(x) + + x = self.drop_path(x) + + x += shortcut + x = self.act3(x) + + return x + + +class PatchMerging(nn.Module): + + def __init__(self, input_resolution, dim, out_dim, activation): + super().__init__() + + self.input_resolution = input_resolution + self.dim = dim + self.out_dim = out_dim + self.act = activation() + self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0) + stride_c = 2 + if (out_dim == 320 or out_dim == 448 or out_dim == 576): + stride_c = 1 + self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim) + self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0) + + def forward(self, x): + if x.ndim == 3: + H, W = self.input_resolution + B = len(x) + # (B, C, H, W) + x = x.view(B, H, W, -1).permute(0, 3, 1, 2) + + x = self.conv1(x) + x = self.act(x) + + x = self.conv2(x) + x = self.act(x) + x = self.conv3(x) + x = x.flatten(2).transpose(1, 2) + return x + + +class ConvLayer(nn.Module): + + def __init__( + self, + dim, + input_resolution, + depth, + activation, + drop_path=0., + downsample=None, + use_checkpoint=False, + out_dim=None, + conv_expand_ratio=4., + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + MBConv( + dim, + dim, + conv_expand_ratio, + activation, + drop_path[i] if isinstance(drop_path, list) else drop_path, + ) for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + +class Mlp(nn.Module): + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.norm = nn.LayerNorm(in_features) + self.fc1 = nn.Linear(in_features, hidden_features) + self.fc2 = nn.Linear(hidden_features, out_features) + self.act = act_layer() + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.norm(x) + + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(torch.nn.Module): + + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=(14, 14), + ): + super().__init__() + # (h, w) + assert isinstance(resolution, tuple) and len(resolution) == 2 + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + + self.norm = nn.LayerNorm(dim) + self.qkv = nn.Linear(dim, h) + self.proj = nn.Linear(self.dh, dim) + + points = list(itertools.product(range(resolution[0]), range(resolution[1]))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,N,C) + B, N, _ = x.shape + + # Normalization + x = self.norm(x) + + qkv = self.qkv(x) + # (B, N, num_heads, d) + q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3) + # (B, num_heads, N, d) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + self.ab = self.ab.to(self.attention_biases.device) + + attn = ((q @ k.transpose(-2, -1)) * self.scale + + (self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab)) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class TinyViTBlock(nn.Module): + r""" TinyViT Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int, int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + local_conv_size (int): the kernel size of the convolution between + Attention and MLP. Default: 3 + activation (torch.nn): the activation function. Default: nn.GELU + """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + mlp_ratio=4., + drop=0., + drop_path=0., + local_conv_size=3, + activation=nn.GELU, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + assert window_size > 0, 'window_size must be greater than 0' + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + # NOTE: `DropPath` is needed only for training. + # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path = nn.Identity() + + assert dim % num_heads == 0, 'dim must be divisible by num_heads' + head_dim = dim // num_heads + + window_resolution = (window_size, window_size) + self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution) + + mlp_hidden_dim = int(dim * mlp_ratio) + mlp_activation = activation + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=mlp_activation, drop=drop) + + pad = local_conv_size // 2 + self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, 'input feature has wrong size' + res_x = x + if H == self.window_size and W == self.window_size: + x = self.attn(x) + else: + x = x.view(B, H, W, C) + pad_b = (self.window_size - H % self.window_size) % self.window_size + pad_r = (self.window_size - W % self.window_size) % self.window_size + padding = pad_b > 0 or pad_r > 0 + + if padding: + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_size + nW = pW // self.window_size + # window partition + x = x.view(B, nH, self.window_size, nW, self.window_size, + C).transpose(2, 3).reshape(B * nH * nW, self.window_size * self.window_size, C) + x = self.attn(x) + # window reverse + x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C) + + if padding: + x = x[:, :H, :W].contiguous() + + x = x.view(B, L, C) + + x = res_x + self.drop_path(x) + + x = x.transpose(1, 2).reshape(B, C, H, W) + x = self.local_conv(x) + x = x.view(B, C, L).transpose(1, 2) + + x = x + self.drop_path(self.mlp(x)) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \ + f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}' + + +class BasicLayer(nn.Module): + """ A basic TinyViT layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + local_conv_size (int): the kernel size of the depthwise convolution between attention and MLP. Default: 3 + activation (torch.nn): the activation function. Default: nn.GELU + out_dim (int | optional): the output dimension of the layer. Default: None + """ + + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + drop=0., + drop_path=0., + downsample=None, + use_checkpoint=False, + local_conv_size=3, + activation=nn.GELU, + out_dim=None, + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + TinyViTBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + local_conv_size=local_conv_size, + activation=activation, + ) for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}' + + +class LayerNorm2d(nn.Module): + + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class TinyViT(nn.Module): + + def __init__( + self, + img_size=224, + in_chans=3, + num_classes=1000, + embed_dims=[96, 192, 384, 768], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.1, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=1.0, + ): + super().__init__() + self.img_size = img_size + self.num_classes = num_classes + self.depths = depths + self.num_layers = len(depths) + self.mlp_ratio = mlp_ratio + + activation = nn.GELU + + self.patch_embed = PatchEmbed(in_chans=in_chans, + embed_dim=embed_dims[0], + resolution=img_size, + activation=activation) + + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + kwargs = dict( + dim=embed_dims[i_layer], + input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)), + patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))), + # input_resolution=(patches_resolution[0] // (2 ** i_layer), + # patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + out_dim=embed_dims[min(i_layer + 1, + len(embed_dims) - 1)], + activation=activation, + ) + if i_layer == 0: + layer = ConvLayer( + conv_expand_ratio=mbconv_expand_ratio, + **kwargs, + ) + else: + layer = BasicLayer(num_heads=num_heads[i_layer], + window_size=window_sizes[i_layer], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + local_conv_size=local_conv_size, + **kwargs) + self.layers.append(layer) + + # Classifier head + self.norm_head = nn.LayerNorm(embed_dims[-1]) + self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + + # init weights + self.apply(self._init_weights) + self.set_layer_lr_decay(layer_lr_decay) + self.neck = nn.Sequential( + nn.Conv2d( + embed_dims[-1], + 256, + kernel_size=1, + bias=False, + ), + LayerNorm2d(256), + nn.Conv2d( + 256, + 256, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(256), + ) + + def set_layer_lr_decay(self, layer_lr_decay): + decay_rate = layer_lr_decay + + # layers -> blocks (depth) + depth = sum(self.depths) + lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)] + + def _set_lr_scale(m, scale): + for p in m.parameters(): + p.lr_scale = scale + + self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0])) + i = 0 + for layer in self.layers: + for block in layer.blocks: + block.apply(lambda x: _set_lr_scale(x, lr_scales[i])) + i += 1 + if layer.downsample is not None: + layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1])) + assert i == depth + for m in [self.norm_head, self.head]: + m.apply(lambda x: _set_lr_scale(x, lr_scales[-1])) + + for k, p in self.named_parameters(): + p.param_name = k + + def _check_lr_scale(m): + for p in m.parameters(): + assert hasattr(p, 'lr_scale'), p.param_name + + self.apply(_check_lr_scale) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # NOTE: This initialization is needed only for training. + # trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'attention_biases'} + + def forward_features(self, x): + # x: (N, C, H, W) + x = self.patch_embed(x) + + x = self.layers[0](x) + start_i = 1 + + for i in range(start_i, len(self.layers)): + layer = self.layers[i] + x = layer(x) + B, _, C = x.size() + x = x.view(B, 64, 64, C) + x = x.permute(0, 3, 1, 2) + x = self.neck(x) + return x + + def forward(self, x): + x = self.forward_features(x) + return x diff --git a/ultralytics/vit/sam/predict.py b/ultralytics/vit/sam/predict.py index 063955de79..47a9d55c42 100644 --- a/ultralytics/vit/sam/predict.py +++ b/ultralytics/vit/sam/predict.py @@ -2,32 +2,298 @@ import numpy as np import torch +import torch.nn.functional as F +import torchvision +from ultralytics.yolo.data.augment import LetterBox from ultralytics.yolo.engine.predictor import BasePredictor from ultralytics.yolo.engine.results import Results +from ultralytics.yolo.utils import DEFAULT_CFG, ops from ultralytics.yolo.utils.torch_utils import select_device -from .modules.mask_generator import SamAutomaticMaskGenerator +from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score, + generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks) +from .build import build_sam class Predictor(BasePredictor): + def __init__(self, cfg=DEFAULT_CFG, overrides={}, _callbacks=None): + overrides.update(dict(task='segment', mode='predict', imgsz=1024)) + super().__init__(cfg, overrides, _callbacks) + # SAM needs retina_masks=True, or the results would be a mess. + self.args.retina_masks = True + # Args for set_image + self.im = None + self.features = None + # Args for segment everything + self.segment_all = False + def preprocess(self, im): - """Prepares input image for inference.""" - # TODO: Only support bs=1 for now - # im = ResizeLongestSide(1024).apply_image(im[0]) - # im = torch.as_tensor(im, device=self.device) - # im = im.permute(2, 0, 1).contiguous()[None, :, :, :] - return im[0] + """Prepares input image before inference. + + Args: + im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list. + """ + if self.im is not None: + return self.im + not_tensor = not isinstance(im, torch.Tensor) + if not_tensor: + im = np.stack(self.pre_transform(im)) + im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w) + im = np.ascontiguousarray(im) # contiguous + im = torch.from_numpy(im) + + img = im.to(self.device) + img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 + if not_tensor: + img = (img - self.mean) / self.std + return img + + def pre_transform(self, im): + """Pre-transform input image before inference. + + Args: + im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list. + + Return: A list of transformed imgs. + """ + assert len(im) == 1, 'SAM model has not supported batch inference yet!' + return [LetterBox(self.args.imgsz, auto=False, center=False)(image=x) for x in im] + + def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs): + """ + Predict masks for the given input prompts, using the currently set image. + + Args: + im (torch.Tensor): The preprocessed image, (N, C, H, W). + bboxes (np.ndarray | List, None): (N, 4), in XYXY format. + points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels. + labels (np.ndarray | List, None): (N, ), labels for the point prompts. + 1 indicates a foreground point and 0 indicates a background point. + masks (np.ndarray, None): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form (N, H, W), where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if all([i is None for i in [bboxes, points, masks]]): + return self.generate(im, *args, **kwargs) + return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output) + + def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False): + """ + Predict masks for the given input prompts, using the currently set image. + + Args: + im (torch.Tensor): The preprocessed image, (N, C, H, W). + bboxes (np.ndarray | List, None): (N, 4), in XYXY format. + points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels. + labels (np.ndarray | List, None): (N, ), labels for the point prompts. + 1 indicates a foreground point and 0 indicates a background point. + masks (np.ndarray, None): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form (N, H, W), where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + features = self.model.image_encoder(im) if self.features is None else self.features + + src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:] + r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1]) + # Transform input prompts + if points is not None: + points = torch.as_tensor(points, dtype=torch.float32, device=self.device) + points = points[None] if points.ndim == 1 else points + # Assuming labels are all positive if users don't pass labels. + if labels is None: + labels = np.ones(points.shape[0]) + labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device) + points *= r + # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1) + points, labels = points[:, None, :], labels[:, None] + if bboxes is not None: + bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device) + bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes + bboxes *= r + if masks is not None: + masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device) + masks = masks[:, None, :, :] + + points = (points, labels) if points is not None else None + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=bboxes, + masks=masks, + ) + + # Predict masks + pred_masks, pred_scores = self.model.mask_decoder( + image_embeddings=features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, ) + # `d` could be 1 or 3 depends on `multimask_output`. + return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) + + def generate(self, + im, + crop_n_layers=0, + crop_overlap_ratio=512 / 1500, + crop_downscale_factor=1, + point_grids=None, + points_stride=32, + points_batch_size=64, + conf_thres=0.88, + stability_score_thresh=0.95, + stability_score_offset=0.95, + crop_nms_thresh=0.7): + """Segment the whole image. + + Args: + im (torch.Tensor): The preprocessed image, (N, C, H, W). + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray), None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + points_stride (int, None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_batch_size (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + conf_thres (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + """ + self.segment_all = True + ih, iw = im.shape[2:] + crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio) + if point_grids is None: + point_grids = build_all_layer_point_grids( + points_stride, + crop_n_layers, + crop_downscale_factor, + ) + pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], [] + for crop_region, layer_idx in zip(crop_regions, layer_idxs): + x1, y1, x2, y2 = crop_region + w, h = x2 - x1, y2 - y1 + area = torch.tensor(w * h, device=im.device) + points_scale = np.array([[w, h]]) # w, h + # Crop image and interpolate to input size + crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode='bilinear', align_corners=False) + # (num_points, 2) + points_for_image = point_grids[layer_idx] * points_scale + crop_masks, crop_scores, crop_bboxes = [], [], [] + for (points, ) in batch_iterator(points_batch_size, points_for_image): + pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True) + # Interpolate predicted masks to input size + pred_mask = F.interpolate(pred_mask[None], (h, w), mode='bilinear', align_corners=False)[0] + idx = pred_score > conf_thres + pred_mask, pred_score = pred_mask[idx], pred_score[idx] + + stability_score = calculate_stability_score(pred_mask, self.model.mask_threshold, + stability_score_offset) + idx = stability_score > stability_score_thresh + pred_mask, pred_score = pred_mask[idx], pred_score[idx] + # Bool type is much more memory-efficient. + pred_mask = pred_mask > self.model.mask_threshold + # (N, 4) + pred_bbox = batched_mask_to_box(pred_mask).float() + keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih]) + if not torch.all(keep_mask): + pred_bbox = pred_bbox[keep_mask] + pred_mask = pred_mask[keep_mask] + pred_score = pred_score[keep_mask] + + crop_masks.append(pred_mask) + crop_bboxes.append(pred_bbox) + crop_scores.append(pred_score) + + # Do nms within this crop + crop_masks = torch.cat(crop_masks) + crop_bboxes = torch.cat(crop_bboxes) + crop_scores = torch.cat(crop_scores) + keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou) # NMS + crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region) + crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw) + crop_scores = crop_scores[keep] + + pred_masks.append(crop_masks) + pred_bboxes.append(crop_bboxes) + pred_scores.append(crop_scores) + region_areas.append(area.expand(len(crop_masks))) + + pred_masks = torch.cat(pred_masks) + pred_bboxes = torch.cat(pred_bboxes) + pred_scores = torch.cat(pred_scores) + region_areas = torch.cat(region_areas) + + # Remove duplicate masks between crops + if len(crop_regions) > 1: + scores = 1 / region_areas + keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh) + pred_masks = pred_masks[keep] + pred_bboxes = pred_bboxes[keep] + pred_scores = pred_scores[keep] + + return pred_masks, pred_scores, pred_bboxes def setup_model(self, model): """Set up YOLO model with specified thresholds and device.""" device = select_device(self.args.device) + if model is None: + model = build_sam(self.args.model) model.eval() - self.model = SamAutomaticMaskGenerator(model.to(device), - pred_iou_thresh=self.args.conf, - box_nms_thresh=self.args.iou) + self.model = model.to(device) self.device = device + self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device) + self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device) # TODO: Temporary settings for compatibility self.model.pt = False self.model.triton = False @@ -35,20 +301,96 @@ class Predictor(BasePredictor): self.model.fp16 = False self.done_warmup = True - def postprocess(self, preds, path, orig_imgs): + def postprocess(self, preds, img, orig_imgs): """Postprocesses inference output predictions to create detection masks for objects.""" - names = dict(enumerate(list(range(len(preds))))) + # (N, 1, H, W), (N, 1) + pred_masks, pred_scores = preds[:2] + pred_bboxes = preds[2] if self.segment_all else None + names = dict(enumerate([str(i) for i in range(len(pred_masks))])) results = [] - # TODO - for i, pred in enumerate([preds]): - masks = torch.from_numpy(np.stack([p['segmentation'] for p in pred], axis=0)) + for i, masks in enumerate([pred_masks]): orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs + if pred_bboxes is not None: + pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False) + cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device) + pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1) + + masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0] + masks = masks > self.model.mask_threshold # to bool path = self.batch[0] img_path = path[i] if isinstance(path, list) else path - results.append(Results(orig_img=orig_img, path=img_path, names=names, masks=masks)) + results.append(Results(orig_img=orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes)) + # Reset segment-all mode. + self.segment_all = False return results - # def __call__(self, source=None, model=None, stream=False): - # frame = cv2.imread(source) - # preds = self.model.generate(frame) - # return self.postprocess(preds, source, frame) + def setup_source(self, source): + """Sets up source and inference mode.""" + if source is not None: + super().setup_source(source) + + def set_image(self, image): + """Set image in advance. + Args: + + image (str | np.ndarray): image file path or np.ndarray image by cv2. + """ + if self.model is None: + model = build_sam(self.args.model) + self.setup_model(model) + self.setup_source(image) + assert len(self.dataset) == 1, '`set_image` only supports setting one image!' + for batch in self.dataset: + im = self.preprocess(batch[1]) + self.features = self.model.image_encoder(im) + self.im = im + break + + def reset_image(self): + self.im = None + self.features = None + + @staticmethod + def remove_small_regions(masks, min_area=0, nms_thresh=0.7): + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. Requires open-cv as a dependency. + + Args: + masks (torch.Tensor): Masks, (N, H, W). + min_area (int): Minimum area threshold. + nms_thresh (float): NMS threshold. + """ + if len(masks) == 0: + return masks + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for mask in masks: + mask = mask.cpu().numpy() + mask, changed = remove_small_regions(mask, min_area, mode='holes') + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode='islands') + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + new_masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(new_masks) + keep = torchvision.ops.nms( + boxes.float(), + torch.as_tensor(scores), + nms_thresh, + ) + + # Only recalculate masks for masks that have changed + for i in keep: + if scores[i] == 0.0: + masks[i] = new_masks[i] + + return masks[keep] diff --git a/ultralytics/yolo/data/annotator.py b/ultralytics/yolo/data/annotator.py index e841df631a..f69f325bed 100644 --- a/ultralytics/yolo/data/annotator.py +++ b/ultralytics/yolo/data/annotator.py @@ -1,8 +1,6 @@ from pathlib import Path -from ultralytics import YOLO -from ultralytics.vit.sam import PromptPredictor, build_sam -from ultralytics.yolo.utils.torch_utils import select_device +from ultralytics import SAM, YOLO def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None): @@ -16,33 +14,21 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir (str | None | optional): Directory to save the annotated results. Defaults to a 'labels' folder in the same directory as 'data'. """ - device = select_device(device) det_model = YOLO(det_model) - sam_model = build_sam(sam_model) - det_model.to(device) - sam_model.to(device) + sam_model = SAM(sam_model) if not output_dir: output_dir = Path(str(data)).parent / 'labels' Path(output_dir).mkdir(exist_ok=True, parents=True) - prompt_predictor = PromptPredictor(sam_model) - det_results = det_model(data, stream=True) + det_results = det_model(data, stream=True, device=device) for result in det_results: boxes = result.boxes.xyxy # Boxes object for bbox outputs class_ids = result.boxes.cls.int().tolist() # noqa if len(class_ids): - prompt_predictor.set_image(result.orig_img) - masks, _, _ = prompt_predictor.predict_torch( - point_coords=None, - point_labels=None, - boxes=prompt_predictor.transform.apply_boxes_torch(boxes, result.orig_shape[:2]), - multimask_output=False, - ) - - result.update(masks=masks.squeeze(1)) - segments = result.masks.xyn # noqa + sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device) + segments = sam_results[0].masks.xyn # noqa with open(str(Path(output_dir) / Path(result.path).stem) + '.txt', 'w') as f: for i in range(len(segments)): diff --git a/ultralytics/yolo/data/augment.py b/ultralytics/yolo/data/augment.py index e0c0a8f91b..d6881595bf 100644 --- a/ultralytics/yolo/data/augment.py +++ b/ultralytics/yolo/data/augment.py @@ -538,13 +538,14 @@ class RandomFlip: class LetterBox: """Resize image and padding for detection, instance segmentation, pose.""" - def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32): + def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, center=True, stride=32): """Initialize LetterBox object with specific parameters.""" self.new_shape = new_shape self.auto = auto self.scaleFill = scaleFill self.scaleup = scaleup self.stride = stride + self.center = center # Put the image in the middle or top-left def __call__(self, labels=None, image=None): """Return updated labels and image with added border.""" @@ -572,15 +573,16 @@ class LetterBox: new_unpad = (new_shape[1], new_shape[0]) ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios - dw /= 2 # divide padding into 2 sides - dh /= 2 + if self.center: + dw /= 2 # divide padding into 2 sides + dh /= 2 if labels.get('ratio_pad'): labels['ratio_pad'] = (labels['ratio_pad'], (dw, dh)) # for evaluation if shape[::-1] != new_unpad: # resize img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) - top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) - left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1)) img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)) # add border diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index 0f6eb0d381..e326e5c1fe 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -131,6 +131,11 @@ class BasePredictor: img /= 255 # 0 - 255 to 0.0 - 1.0 return img + def inference(self, im, *args, **kwargs): + visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem, + mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False + return self.model(im, augment=self.args.augment, visualize=visualize) + def pre_transform(self, im): """Pre-transform input image before inference. @@ -181,13 +186,13 @@ class BasePredictor: """Post-processes predictions for an image and returns them.""" return preds - def __call__(self, source=None, model=None, stream=False): + def __call__(self, source=None, model=None, stream=False, *args, **kwargs): """Performs inference on an image or stream.""" self.stream = stream if stream: - return self.stream_inference(source, model) + return self.stream_inference(source, model, *args, **kwargs) else: - return list(self.stream_inference(source, model)) # merge list of Result into one + return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one def predict_cli(self, source=None, model=None): """Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode.""" @@ -209,7 +214,7 @@ class BasePredictor: self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs @smart_inference_mode() - def stream_inference(self, source=None, model=None): + def stream_inference(self, source=None, model=None, *args, **kwargs): """Streams real-time inference on camera feed and saves results to file.""" if self.args.verbose: LOGGER.info('') @@ -236,8 +241,6 @@ class BasePredictor: self.run_callbacks('on_predict_batch_start') self.batch = batch path, im0s, vid_cap, s = batch - visualize = increment_path(self.save_dir / Path(path[0]).stem, - mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False # Preprocess with profilers[0]: @@ -245,7 +248,7 @@ class BasePredictor: # Inference with profilers[1]: - preds = self.model(im, augment=self.args.augment, visualize=visualize) + preds = self.inference(im, *args, **kwargs) # Postprocess with profilers[2]: diff --git a/ultralytics/yolo/engine/results.py b/ultralytics/yolo/engine/results.py index 7b33def91a..e934730d62 100644 --- a/ultralytics/yolo/engine/results.py +++ b/ultralytics/yolo/engine/results.py @@ -170,7 +170,7 @@ class Results(SimpleClass): font='Arial.ttf', pil=False, img=None, - img_gpu=None, + im_gpu=None, kpt_line=True, labels=True, boxes=True, @@ -188,7 +188,7 @@ class Results(SimpleClass): font (str): The font to use for the text. pil (bool): Whether to return the image as a PIL Image. img (numpy.ndarray): Plot to another image. if not, plot to original image. - img_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting. + im_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting. kpt_line (bool): Whether to draw lines connecting keypoints. labels (bool): Whether to plot the label of bounding boxes. boxes (bool): Whether to plot the bounding boxes. @@ -226,12 +226,12 @@ class Results(SimpleClass): # Plot Segment results if pred_masks and show_masks: - if img_gpu is None: + if im_gpu is None: img = LetterBox(pred_masks.shape[1:])(image=annotator.result()) - img_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute( + im_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute( 2, 0, 1).flip(0).contiguous() / 255 idx = pred_boxes.cls if pred_boxes else range(len(pred_masks)) - annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=img_gpu) + annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu) # Plot Detect results if pred_boxes and show_boxes: diff --git a/ultralytics/yolo/fastsam/utils.py b/ultralytics/yolo/fastsam/utils.py index c5b6cc235e..dcc71dcfa9 100644 --- a/ultralytics/yolo/fastsam/utils.py +++ b/ultralytics/yolo/fastsam/utils.py @@ -8,12 +8,12 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20): Adjust bounding boxes to stick to image border if they are within a certain threshold. Args: - boxes: (n, 4) - image_shape: (height, width) - threshold: pixel threshold + boxes (torch.Tensor): (n, 4) + image_shape (tuple): (height, width) + threshold (int): pixel threshold Returns: - adjusted_boxes: adjusted bounding boxes + adjusted_boxes (torch.Tensor): adjusted bounding boxes """ # Image dimensions @@ -32,11 +32,11 @@ def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=Fals Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes. Args: - box1: (4, ) - boxes: (n, 4) + box1 (torch.Tensor): (4, ) + boxes (torch.Tensor): (n, 4) Returns: - high_iou_indices: Indices of boxes with IoU > thres + high_iou_indices (torch.Tensor): Indices of boxes with IoU > thres """ boxes = adjust_bboxes_to_image_border(boxes, image_shape) # obtain coordinates for intersections diff --git a/ultralytics/yolo/utils/downloads.py b/ultralytics/yolo/utils/downloads.py index 53f58cfddc..c13192157f 100644 --- a/ultralytics/yolo/utils/downloads.py +++ b/ultralytics/yolo/utils/downloads.py @@ -21,7 +21,8 @@ GITHUB_ASSET_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in ('' [f'yolo_nas_{k}.pt' for k in 'sml'] + \ [f'sam_{k}.pt' for k in 'bl'] + \ [f'FastSAM-{k}.pt' for k in 'sx'] + \ - [f'rtdetr-{k}.pt' for k in 'lx'] + [f'rtdetr-{k}.pt' for k in 'lx'] + \ + ['mobile_sam.pt'] GITHUB_ASSET_STEMS = [Path(k).stem for k in GITHUB_ASSET_NAMES] diff --git a/ultralytics/yolo/utils/instance.py b/ultralytics/yolo/utils/instance.py index 3566f6e2ea..68f9613eee 100644 --- a/ultralytics/yolo/utils/instance.py +++ b/ultralytics/yolo/utils/instance.py @@ -20,6 +20,7 @@ def _ntuple(n): return parse +to_2tuple = _ntuple(2) to_4tuple = _ntuple(4) # `xyxy` means left top and right bottom diff --git a/ultralytics/yolo/utils/ops.py b/ultralytics/yolo/utils/ops.py index b9199ad2f4..bb9ca49a8f 100644 --- a/ultralytics/yolo/utils/ops.py +++ b/ultralytics/yolo/utils/ops.py @@ -92,7 +92,7 @@ def segment2box(segment, width=640, height=640): 4, dtype=segment.dtype) # xyxy -def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None): +def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True): """ Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in (img1_shape) to the shape of a different image (img0_shape). @@ -103,6 +103,8 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None): img0_shape (tuple): the shape of the target image, in the format of (height, width). ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be calculated based on the size difference between the two images. + padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular + rescaling. Returns: boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2) @@ -115,8 +117,9 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None): gain = ratio_pad[0][0] pad = ratio_pad[1] - boxes[..., [0, 2]] -= pad[0] # x padding - boxes[..., [1, 3]] -= pad[1] # y padding + if padding: + boxes[..., [0, 2]] -= pad[0] # x padding + boxes[..., [1, 3]] -= pad[1] # y padding boxes[..., :4] /= gain clip_boxes(boxes, img0_shape) return boxes @@ -552,7 +555,7 @@ def crop_mask(masks, boxes): It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box Args: - masks (torch.Tensor): [h, w, n] tensor of masks + masks (torch.Tensor): [n, h, w] tensor of masks boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form Returns: @@ -634,18 +637,36 @@ def process_mask_native(protos, masks_in, bboxes, shape): """ c, mh, mw = protos.shape # CHW masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) - gain = min(mh / shape[0], mw / shape[1]) # gain = old / new - pad = (mw - shape[1] * gain) / 2, (mh - shape[0] * gain) / 2 # wh padding - top, left = int(pad[1]), int(pad[0]) # y, x - bottom, right = int(mh - pad[1]), int(mw - pad[0]) - masks = masks[:, top:bottom, left:right] - - masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW + masks = scale_masks(masks[None], shape)[0] # CHW masks = crop_mask(masks, bboxes) # CHW return masks.gt_(0.5) -def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False): +def scale_masks(masks, shape, padding=True): + """ + Rescale segment masks to shape. + + Args: + masks (torch.Tensor): (N, C, H, W). + shape (tuple): Height and width. + padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular + rescaling. + """ + mh, mw = masks.shape[2:] + gain = min(mh / shape[0], mw / shape[1]) # gain = old / new + pad = [mw - shape[1] * gain, mh - shape[0] * gain] # wh padding + if padding: + pad[0] /= 2 + pad[1] /= 2 + top, left = (int(pad[1]), int(pad[0])) if padding else (0, 0) # y, x + bottom, right = (int(mh - pad[1]), int(mw - pad[0])) + masks = masks[..., top:bottom, left:right] + + masks = F.interpolate(masks, shape, mode='bilinear', align_corners=False) # NCHW + return masks + + +def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True): """ Rescale segment coordinates (xyxy) from img1_shape to img0_shape @@ -655,6 +676,8 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False img0_shape (tuple): the shape of the image that the segmentation is being applied to ratio_pad (tuple): the ratio of the image size to the padded image size. normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False + padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular + rescaling. Returns: coords (torch.Tensor): the segmented image. @@ -666,8 +689,9 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False gain = ratio_pad[0][0] pad = ratio_pad[1] - coords[..., 0] -= pad[0] # x padding - coords[..., 1] -= pad[1] # y padding + if padding: + coords[..., 0] -= pad[0] # x padding + coords[..., 1] -= pad[1] # y padding coords[..., 0] /= gain coords[..., 1] /= gain clip_coords(coords, img0_shape)