@ -1,4 +1,12 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Generate predictions using the Segment Anything Model ( SAM ) .
SAM is an advanced image segmentation model offering features like promptable segmentation and zero - shot performance .
This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation
using SAM . It forms an integral part of the Ultralytics framework and is designed for high - performance , real - time image
segmentation tasks .
"""
import numpy as np
import torch
@ -18,71 +26,86 @@ from .build import build_sam
class Predictor ( BasePredictor ) :
"""
A prediction class for segmentation tasks , extending the BasePredictor .
Predictor class for the Segment Anything Model ( SAM ) , extending BasePredictor .
This class serves a s an interface for model inference for segmentation tasks .
It can preprocess input images , perform inference , and postprocess the output .
It also supports handling various types of input prompts including bounding boxes ,
points , and low - resolution masks for better prediction results .
The class provide s an interface for model inference tailored to image segmentation tasks .
With advanced architecture and promptable segmentation capabilities , it facilitates flexible and real - time
mask generation . The class is capable of working with various types of prompts such as bounding boxes ,
points , and low - resolution masks .
Attributes :
cfg ( dict ) : Configuration dictionary .
overrides ( dict ) : Dictionary of overriding values .
_callbacks ( dict ) : Dictionary of callback functions .
args ( namespace ) : Argument namespace .
im ( torch . Tensor ) : Preprocessed image for current prediction .
features ( torch . Tensor ) : Image features .
prompts ( dict ) : Dictionary of prompts like bboxes , points , mask s.
segment_all ( bool ) : Whether to perform segmentation on all objects or not .
cfg ( dict ) : Configuration dictionary specifying model and task - related parameters .
overrides ( dict ) : Dictionary containing values that override the default configuration .
_callbacks ( dict ) : Dictionary of user - defined callback functions to augment behavior .
args ( namespace ) : Namespace to hold command - line arguments or other operational variables .
im ( torch . Tensor ) : Preprocessed input image tensor .
features ( torch . Tensor ) : Extracted image features used for inference .
prompts ( dict ) : Collection of various prompt types , such as bounding boxes and point s.
segment_all ( bool ) : Flag to control whether to segment all objects in the image or only specified ones .
"""
def __init__ ( self , cfg = DEFAULT_CFG , overrides = None , _callbacks = None ) :
""" Initializes the Predictor class with default or provided configuration, overrides, and callbacks. """
"""
Initialize the Predictor with configuration , overrides , and callbacks .
The method sets up the Predictor object and applies any configuration overrides or callbacks provided . It
initializes task - specific settings for SAM , such as retina_masks being set to True for optimal results .
Args :
cfg ( dict ) : Configuration dictionary .
overrides ( dict , optional ) : Dictionary of values to override default configuration .
_callbacks ( dict , optional ) : Dictionary of callback functions to customize behavior .
"""
if overrides is None :
overrides = { }
overrides . update ( dict ( task = ' segment ' , mode = ' predict ' , imgsz = 1024 ) )
super ( ) . __init__ ( cfg , overrides , _callbacks )
# SAM needs retina_masks=True, or the results would be a mess.
self . args . retina_masks = True
# Args for set_image
self . im = None
self . features = None
# Args for set_prompts
self . prompts = { }
# Args for segment everything
self . segment_all = False
def preprocess ( self , im ) :
"""
Prepares input image before inference .
Preprocess the input image for model inference .
The method prepares the input image by applying transformations and normalization .
It supports both torch . Tensor and list of np . ndarray as input formats .
Args :
im ( torch . Tensor | List ( np . ndarray ) ) : BCHW for tensor , [ ( HWC ) x B ] for list .
im ( torch . Tensor | List [ np . ndarray ] ) : BCHW tensor format or list of HWC numpy arrays .
Returns :
torch . Tensor : The preprocessed image tensor .
"""
if self . im is not None :
return self . im
not_tensor = not isinstance ( im , torch . Tensor )
if not_tensor :
im = np . stack ( self . pre_transform ( im ) )
im = im [ . . . , : : - 1 ] . transpose ( ( 0 , 3 , 1 , 2 ) ) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
im = np . ascontiguousarray ( im ) # contiguous
im = im [ . . . , : : - 1 ] . transpose ( ( 0 , 3 , 1 , 2 ) )
im = np . ascontiguousarray ( im )
im = torch . from_numpy ( im )
im = im . to ( self . device )
im = im . half ( ) if self . model . fp16 else im . float ( ) # uint8 to fp16/32
im = im . half ( ) if self . model . fp16 else im . float ( )
if not_tensor :
im = ( im - self . mean ) / self . std
return im
def pre_transform ( self , im ) :
"""
Pre - transform input image before inference .
Perform initial transformations on the input image for preprocessing .
The method applies transformations such as resizing to prepare the image for further preprocessing .
Currently , batched inference is not supported ; hence the list length should be 1.
Args :
im ( List ( np . ndarray ) ) : ( N , 3 , h , w ) for tensor , [ ( h , w , 3 ) x N ] for list .
im ( List [ np . ndarray ] ) : List containing images in HWC numpy array forma t.
Returns :
( list ) : A l ist of transformed images .
List [ np . ndarray ] : L ist of transformed images .
"""
assert len ( im ) == 1 , ' SAM model does not currently support batched inference '
letterbox = LetterBox ( self . args . imgsz , auto = False , center = False )
@ -90,69 +113,52 @@ class Predictor(BasePredictor):
def inference ( self , im , bboxes = None , points = None , labels = None , masks = None , multimask_output = False , * args , * * kwargs ) :
"""
Predict masks for the given input prompts , using the currently set image .
Perform image segmentation inference based on the given input cues , using the currently loaded image . This
method leverages SAM ' s (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
mask decoder for real - time and promptable segmentation tasks .
Args :
im ( torch . Tensor ) : The preprocessed image , ( N , C , H , W ) .
bboxes ( np . ndarray | List , None ) : ( N , 4 ) , in XYXY format .
points ( np . ndarray | List , None ) : ( N , 2 ) , Each point is in ( X , Y ) in pixels .
labels ( np . ndarray | List , None ) : ( N , ) , labels for the point prompts .
1 indicates a foreground point and 0 indicates a background point .
masks ( np . ndarray , None ) : A low resolution mask input to the model , typically
coming from a previous prediction iteration . Has form ( N , H , W ) , where
for SAM , H = W = 256.
multimask_output ( bool ) : If true , the model will return three masks .
For ambiguous input prompts ( such as a single click ) , this will often
produce better masks than a single prediction . If only a single
mask is needed , the model ' s predicted quality score can be used
to select the best mask . For non - ambiguous prompts , such as multiple
input prompts , multimask_output = False can give better results .
im ( torch . Tensor ) : The preprocessed input image in tensor format , with shape ( N , C , H , W ) .
bboxes ( np . ndarray | List , optional ) : Bounding boxes with shape ( N , 4 ) , in XYXY format .
points ( np . ndarray | List , optional ) : Points indicating object locations with shape ( N , 2 ) , in pixel coordinates .
labels ( np . ndarray | List , optional ) : Labels for point prompts , shape ( N , ) . 1 for foreground and 0 for background .
masks ( np . ndarray , optional ) : Low - resolution masks from previous predictions . Shape should be ( N , H , W ) . For SAM , H = W = 256.
multimask_output ( bool , optional ) : Flag to return multiple masks . Helpful for ambiguous prompts . Defaults to False .
Returns :
( np . ndarray ) : The output masks in CxHxW format , where C is the
number of masks , and ( H , W ) is the original image size .
( np . ndarray ) : An array of length C containing the model ' s
predictions for the quality of each mask .
( np . ndarray ) : An array of shape CxHxW , where C is the number
of masks and H = W = 256. These low resolution logits can be passed to
a subsequent iteration as mask input .
tuple : Contains the following three elements .
- np . ndarray : The output masks in shape CxHxW , where C is the number of generated masks .
- np . ndarray : An array of length C containing quality scores predicted by the model for each mask .
- np . ndarray : Low - resolution logits of shape CxHxW for subsequent inference , where H = W = 256.
"""
# Get prompts from self.prompts first
# Override prompts if any stored in self.prompts
bboxes = self . prompts . pop ( ' bboxes ' , bboxes )
points = self . prompts . pop ( ' points ' , points )
masks = self . prompts . pop ( ' masks ' , masks )
if all ( i is None for i in [ bboxes , points , masks ] ) :
return self . generate ( im , * args , * * kwargs )
return self . prompt_inference ( im , bboxes , points , labels , masks , multimask_output )
def prompt_inference ( self , im , bboxes = None , points = None , labels = None , masks = None , multimask_output = False ) :
"""
Predict masks for the given input prompts , using the currently set image .
Internal function for image segmentation inference based on cues like bounding boxes , points , and masks .
Leverages SAM ' s specialized architecture for prompt-based, real-time segmentation.
Args :
im ( torch . Tensor ) : The preprocessed image , ( N , C , H , W ) .
bboxes ( np . ndarray | List , None ) : ( N , 4 ) , in XYXY format .
points ( np . ndarray | List , None ) : ( N , 2 ) , Each point is in ( X , Y ) in pixels .
labels ( np . ndarray | List , None ) : ( N , ) , labels for the point prompts .
1 indicates a foreground point and 0 indicates a background point .
masks ( np . ndarray , None ) : A low resolution mask input to the model , typically
coming from a previous prediction iteration . Has form ( N , H , W ) , where
for SAM , H = W = 256.
multimask_output ( bool ) : If true , the model will return three masks .
For ambiguous input prompts ( such as a single click ) , this will often
produce better masks than a single prediction . If only a single
mask is needed , the model ' s predicted quality score can be used
to select the best mask . For non - ambiguous prompts , such as multiple
input prompts , multimask_output = False can give better results .
im ( torch . Tensor ) : The preprocessed input image in tensor format , with shape ( N , C , H , W ) .
bboxes ( np . ndarray | List , optional ) : Bounding boxes with shape ( N , 4 ) , in XYXY format .
points ( np . ndarray | List , optional ) : Points indicating object locations with shape ( N , 2 ) , in pixel coordinates .
labels ( np . ndarray | List , optional ) : Labels for point prompts , shape ( N , ) . 1 for foreground and 0 for background .
masks ( np . ndarray , optional ) : Low - resolution masks from previous predictions . Shape should be ( N , H , W ) . For SAM , H = W = 256.
multimask_output ( bool , optional ) : Flag to return multiple masks . Helpful for ambiguous prompts . Defaults to False .
Returns :
( np . ndarray ) : The output masks in CxHxW format , where C is the
number of masks , and ( H , W ) is the original image size .
( np . ndarray ) : An array of length C containing the model ' s
predictions for the quality of each mask .
( np . ndarray ) : An array of shape CxHxW , where C is the number
of masks and H = W = 256. These low resolution logits can be passed to
a subsequent iteration as mask input .
tuple : Contains the following three elements .
- np . ndarray : The output masks in shape CxHxW , where C is the number of generated masks .
- np . ndarray : An array of length C containing quality scores predicted by the model for each mask .
- np . ndarray : Low - resolution logits of shape CxHxW for subsequent inference , where H = W = 256.
"""
features = self . model . image_encoder ( im ) if self . features is None else self . features
@ -178,11 +184,7 @@ class Predictor(BasePredictor):
points = ( points , labels ) if points is not None else None
# Embed prompts
sparse_embeddings , dense_embeddings = self . model . prompt_encoder (
points = points ,
boxes = bboxes ,
masks = masks ,
)
sparse_embeddings , dense_embeddings = self . model . prompt_encoder ( points = points , boxes = bboxes , masks = masks )
# Predict masks
pred_masks , pred_scores = self . model . mask_decoder (
@ -210,46 +212,35 @@ class Predictor(BasePredictor):
stability_score_offset = 0.95 ,
crop_nms_thresh = 0.7 ) :
"""
Segment the whole image .
Perform image segmentation using the Segment Anything Model ( SAM ) .
This function segments an entire image into constituent parts by leveraging SAM ' s advanced architecture
and real - time performance capabilities . It can optionally work on image crops for finer segmentation .
Args :
im ( torch . Tensor ) : The preprocessed image , ( N , C , H , W ) .
crop_n_layers ( int ) : If > 0 , mask prediction will be run again on
crops of the image . Sets the number of layers to run , where each
layer has 2 * * i_layer number of image crops .
crop_overlap_ratio ( float ) : Sets the degree to which crops overlap .
In the first crop layer , crops will overlap by this fraction of
the image length . Later layers with more crops scale down this overlap .
crop_downscale_factor ( int ) : The number of points - per - side
sampled in layer n is scaled down by crop_n_points_downscale_factor * * n .
point_grids ( list ( np . ndarray ) , None ) : A list over explicit grids
of points used for sampling , normalized to [ 0 , 1 ] . The nth grid in the
list is used in the nth crop layer . Exclusive with points_per_side .
points_stride ( int , None ) : The number of points to be sampled
along one side of the image . The total number of points is
points_per_side * * 2. If None , ' point_grids ' must provide explicit
point sampling .
points_batch_size ( int ) : Sets the number of points run simultaneously
by the model . Higher numbers may be faster but use more GPU memory .
conf_thres ( float ) : A filtering threshold in [ 0 , 1 ] , using the
model ' s predicted mask quality.
stability_score_thresh ( float ) : A filtering threshold in [ 0 , 1 ] , using
the stability of the mask under changes to the cutoff used to binarize
the model ' s mask predictions.
stability_score_offset ( float ) : The amount to shift the cutoff when
calculated the stability score .
crop_nms_thresh ( float ) : The box IoU cutoff used by non - maximal
suppression to filter duplicate masks between different crops .
im ( torch . Tensor ) : Input tensor representing the preprocessed image with dimensions ( N , C , H , W ) .
crop_n_layers ( int ) : Specifies the number of layers for additional mask predictions on image crops .
Each layer produces 2 * * i_layer number of image crops .
crop_overlap_ratio ( float ) : Determines the extent of overlap between crops . Scaled down in subsequent layers .
crop_downscale_factor ( int ) : Scaling factor for the number of sampled points - per - side in each layer .
point_grids ( list [ np . ndarray ] , optional ) : Custom grids for point sampling normalized to [ 0 , 1 ] .
Used in the nth crop layer .
points_stride ( int , optional ) : Number of points to sample along each side of the image .
Exclusive with ' point_grids ' .
points_batch_size ( int ) : Batch size for the number of points processed simultaneously .
conf_thres ( float ) : Confidence threshold [ 0 , 1 ] for filtering based on the model ' s mask quality prediction.
stability_score_thresh ( float ) : Stability threshold [ 0 , 1 ] for mask filtering based on mask stability .
stability_score_offset ( float ) : Offset value for calculating stability score .
crop_nms_thresh ( float ) : IoU cutoff for Non - Maximum Suppression ( NMS ) to remove duplicate masks between crops .
Returns :
tuple : A tuple containing segmented masks , confidence scores , and bounding boxes .
"""
self . segment_all = True
ih , iw = im . shape [ 2 : ]
crop_regions , layer_idxs = generate_crop_boxes ( ( ih , iw ) , crop_n_layers , crop_overlap_ratio )
if point_grids is None :
point_grids = build_all_layer_point_grids (
points_stride ,
crop_n_layers ,
crop_downscale_factor ,
)
point_grids = build_all_layer_point_grids ( points_stride , crop_n_layers , crop_downscale_factor )
pred_masks , pred_scores , pred_bboxes , region_areas = [ ] , [ ] , [ ] , [ ]
for crop_region , layer_idx in zip ( crop_regions , layer_idxs ) :
x1 , y1 , x2 , y2 = crop_region
@ -312,7 +303,22 @@ class Predictor(BasePredictor):
return pred_masks , pred_scores , pred_bboxes
def setup_model ( self , model , verbose = True ) :
""" Set up YOLO model with specified thresholds and device. """
"""
Initializes the Segment Anything Model ( SAM ) for inference .
This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary
parameters for image normalization and other Ultralytics compatibility settings .
Args :
model ( torch . nn . Module ) : A pre - trained SAM model . If None , a model will be built based on configuration .
verbose ( bool ) : If True , prints selected device information .
Attributes :
model ( torch . nn . Module ) : The SAM model allocated to the chosen device for inference .
device ( torch . device ) : The device to which the model and tensors are allocated .
mean ( torch . Tensor ) : The mean values for image normalization .
std ( torch . Tensor ) : The standard deviation values for image normalization .
"""
device = select_device ( self . args . device , verbose = verbose )
if model is None :
model = build_sam ( self . args . model )
@ -321,7 +327,8 @@ class Predictor(BasePredictor):
self . device = device
self . mean = torch . tensor ( [ 123.675 , 116.28 , 103.53 ] ) . view ( - 1 , 1 , 1 ) . to ( device )
self . std = torch . tensor ( [ 58.395 , 57.12 , 57.375 ] ) . view ( - 1 , 1 , 1 ) . to ( device )
# TODO: Temporary settings for compatibility
# Ultralytics compatibility settings
self . model . pt = False
self . model . triton = False
self . model . stride = 32
@ -329,7 +336,20 @@ class Predictor(BasePredictor):
self . done_warmup = True
def postprocess ( self , preds , img , orig_imgs ) :
""" Post-processes inference output predictions to create detection masks for objects. """
"""
Post - processes SAM ' s inference outputs to generate object detection masks and bounding boxes.
The method scales masks and boxes to the original image size and applies a threshold to the mask predictions . The
SAM model uses advanced architecture and promptable segmentation tasks to achieve real - time performance .
Args :
preds ( tuple ) : The output from SAM model inference , containing masks , scores , and optional bounding boxes .
img ( torch . Tensor ) : The processed input image tensor .
orig_imgs ( list | torch . Tensor ) : The original , unprocessed images .
Returns :
( list ) : List of Results objects containing detection masks , bounding boxes , and other metadata .
"""
# (N, 1, H, W), (N, 1)
pred_masks , pred_scores = preds [ : 2 ]
pred_bboxes = preds [ 2 ] if self . segment_all else None
@ -355,15 +375,30 @@ class Predictor(BasePredictor):
return results
def setup_source ( self , source ) :
""" Sets up source and inference mode. """
"""
Sets up the data source for inference .
This method configures the data source from which images will be fetched for inference . The source could be a
directory , a video file , or other types of image data sources .
Args :
source ( str | Path ) : The path to the image data source for inference .
"""
if source is not None :
super ( ) . setup_source ( source )
def set_image ( self , image ) :
""" Set image in advance.
"""
Preprocesses and sets a single image for inference .
This function sets up the model if not already initialized , configures the data source to the specified image ,
and preprocesses the image for feature extraction . Only one image can be set at a time .
Args :
image ( str | np . ndarray ) : Image file path as a string , or a np . ndarray image read by cv2 .
image ( str | np . ndarray ) : image file path or np . ndarray image by cv2 .
Raises :
AssertionError : If more than one image is set .
"""
if self . model is None :
model = build_sam ( self . args . model )
@ -388,17 +423,20 @@ class Predictor(BasePredictor):
@staticmethod
def remove_small_regions ( masks , min_area = 0 , nms_thresh = 0.7 ) :
"""
Removes small disconnected regions and holes in masks , then reruns box NMS to remove any new duplicates .
Requires open - cv as a dependency .
Perform post - processing on segmentation masks generated by the Segment Anything Model ( SAM ) . Specifically , this
function removes small disconnected regions and holes from the input masks , and then performs Non - Maximum
Suppression ( NMS ) to eliminate any newly created duplicate boxes .
Args :
masks ( torch . Tensor ) : Masks , ( N , H , W ) .
min_area ( int ) : Minimum area threshold .
nms_thresh ( float ) : NMS threshold .
masks ( torch . Tensor ) : A tensor containing the masks to be processed . Shape should be ( N , H , W ) , where N is
the number of masks , H is height , and W is width .
min_area ( int ) : The minimum area below which disconnected regions and holes will be removed . Defaults to 0.
nms_thresh ( float ) : The IoU threshold for the NMS algorithm . Defaults to 0.7 .
Returns :
new_masks ( torch . Tensor ) : New Masks , ( N , H , W ) .
keep ( List [ int ] ) : The indices of the new masks , which can be used to filter
the corresponding boxes .
T ( uple [ torch . Tensor , List [ int ] ] ) :
- new_masks ( torch . Tensor ) : The processed masks with small regions removed . Shape is ( N , H , W ) .
- keep ( List [ int ] ) : The indices of the remaining masks post - NMS , which can be used to filter the boxes .
"""
if len ( masks ) == 0 :
return masks
@ -420,10 +458,6 @@ class Predictor(BasePredictor):
# Recalculate boxes and remove any new duplicates
new_masks = torch . cat ( new_masks , dim = 0 )
boxes = batched_mask_to_box ( new_masks )
keep = torchvision . ops . nms (
boxes . float ( ) ,
torch . as_tensor ( scores ) ,
nms_thresh ,
)
keep = torchvision . ops . nms ( boxes . float ( ) , torch . as_tensor ( scores ) , nms_thresh )
return new_masks [ keep ] . to ( device = masks . device , dtype = masks . dtype ) , keep