`ultralytics 8.2.63` refactor `FastSAMPredictor` (#14582)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/14608/head v8.2.63
Nguyễn Anh Bình 4 months ago committed by GitHub
parent db82d1c6ae
commit 3637516412
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 3
      .github/workflows/format.yml
  2. 4
      docs/en/reference/models/fastsam/utils.md
  3. 2
      ultralytics/__init__.py
  4. 89
      ultralytics/models/fastsam/predict.py
  5. 42
      ultralytics/models/fastsam/utils.py

@ -5,6 +5,8 @@
name: Ultralytics Actions
on:
issues:
types: [opened, edited]
pull_request_target:
branches: [main]
types: [opened, closed, synchronize]
@ -17,6 +19,7 @@ jobs:
uses: ultralytics/actions@main
with:
token: ${{ secrets.GITHUB_TOKEN }} # automatically generated, do not modify
labels: true # autolabel issues and PRs
python: true # format Python code and docstrings
markdown: true # format Markdown
prettier: true # format YAML

@ -13,8 +13,4 @@ keywords: FastSAM, bounding boxes, IoU, Ultralytics, image processing, computer
## ::: ultralytics.models.fastsam.utils.adjust_bboxes_to_image_border
<br><br><hr><br>
## ::: ultralytics.models.fastsam.utils.bbox_iou
<br><br>

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.62"
__version__ = "8.2.63"
import os

@ -1,84 +1,31 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.engine.results import Results
from ultralytics.models.fastsam.utils import bbox_iou
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ops
from ultralytics.models.yolo.segment import SegmentationPredictor
from ultralytics.utils.metrics import box_iou
from .utils import adjust_bboxes_to_image_border
class FastSAMPredictor(DetectionPredictor):
class FastSAMPredictor(SegmentationPredictor):
"""
FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in Ultralytics
YOLO framework.
This class extends the DetectionPredictor, customizing the prediction pipeline specifically for fast SAM.
It adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing
for single-class segmentation.
Attributes:
cfg (dict): Configuration parameters for prediction.
overrides (dict, optional): Optional parameter overrides for custom behavior.
_callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing for single-
class segmentation.
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initializes the FastSAMPredictor class, inheriting from DetectionPredictor and setting the task to 'segment'.
Args:
cfg (dict): Configuration parameters for prediction.
overrides (dict, optional): Optional parameter overrides for custom behavior.
_callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
"""
super().__init__(cfg, overrides, _callbacks)
self.args.task = "segment"
def postprocess(self, preds, img, orig_imgs):
"""
Perform post-processing steps on predictions, including non-max suppression and scaling boxes to original image
size, and returns the final results.
Args:
preds (list): The raw output predictions from the model.
img (torch.Tensor): The processed image tensor.
orig_imgs (list | torch.Tensor): The original image or list of images.
Returns:
(list): A list of Results objects, each containing processed boxes, masks, and other metadata.
"""
p = ops.non_max_suppression(
preds[0],
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=1, # set to 1 class since SAM has no class predictions
classes=self.args.classes,
)
full_box = torch.zeros(p[0].shape[1], device=p[0].device)
full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
full_box = full_box.view(1, -1)
critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
if critical_iou_index.numel() != 0:
full_box[0][4] = p[0][critical_iou_index][:, 4]
full_box[0][6:] = p[0][critical_iou_index][:, 6:]
p[0][critical_iou_index] = full_box
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
for i, (pred, orig_img, img_path) in enumerate(zip(p, orig_imgs, self.batch[0])):
if not len(pred): # save empty boxes
masks = None
elif self.args.retina_masks:
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
else:
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
"""Applies box postprocess for FastSAM predictions."""
results = super().postprocess(preds, img, orig_imgs)
for result in results:
full_box = torch.tensor(
[0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
)
boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
if idx.numel() != 0:
result.boxes.xyxy[idx] = full_box
return results

@ -1,7 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
"""
@ -25,43 +23,3 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
boxes[boxes[:, 2] > w - threshold, 2] = w # x2
boxes[boxes[:, 3] > h - threshold, 3] = h # y2
return boxes
def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
"""
Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
Args:
box1 (torch.Tensor): (4, )
boxes (torch.Tensor): (n, 4)
iou_thres (float): IoU threshold
image_shape (tuple): (height, width)
raw_output (bool): If True, return the raw IoU values instead of the indices
Returns:
high_iou_indices (torch.Tensor): Indices of boxes with IoU > thres
"""
boxes = adjust_bboxes_to_image_border(boxes, image_shape)
# Obtain coordinates for intersections
x1 = torch.max(box1[0], boxes[:, 0])
y1 = torch.max(box1[1], boxes[:, 1])
x2 = torch.min(box1[2], boxes[:, 2])
y2 = torch.min(box1[3], boxes[:, 3])
# Compute the area of intersection
intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
# Compute the area of both individual boxes
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
# Compute the area of union
union = box1_area + box2_area - intersection
# Compute the IoU
iou = intersection / union # Should be shape (n, )
if raw_output:
return 0 if iou.numel() == 0 else iou
# return indices of boxes with IoU > thres
return torch.nonzero(iou > iou_thres).flatten()

Loading…
Cancel
Save