diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index ccc4c10391..32180c5e5b 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -5,6 +5,8 @@ name: Ultralytics Actions on: + issues: + types: [opened, edited] pull_request_target: branches: [main] types: [opened, closed, synchronize] @@ -17,6 +19,7 @@ jobs: uses: ultralytics/actions@main with: token: ${{ secrets.GITHUB_TOKEN }} # automatically generated, do not modify + labels: true # autolabel issues and PRs python: true # format Python code and docstrings markdown: true # format Markdown prettier: true # format YAML diff --git a/docs/en/reference/models/fastsam/utils.md b/docs/en/reference/models/fastsam/utils.md index 43c5617c21..14695908d9 100644 --- a/docs/en/reference/models/fastsam/utils.md +++ b/docs/en/reference/models/fastsam/utils.md @@ -13,8 +13,4 @@ keywords: FastSAM, bounding boxes, IoU, Ultralytics, image processing, computer ## ::: ultralytics.models.fastsam.utils.adjust_bboxes_to_image_border -



- -## ::: ultralytics.models.fastsam.utils.bbox_iou -

diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 3a4ab20152..362d4eead9 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.2.62" +__version__ = "8.2.63" import os diff --git a/ultralytics/models/fastsam/predict.py b/ultralytics/models/fastsam/predict.py index f7ffb2faa3..023c1f9ab8 100644 --- a/ultralytics/models/fastsam/predict.py +++ b/ultralytics/models/fastsam/predict.py @@ -1,84 +1,31 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license - import torch -from ultralytics.engine.results import Results -from ultralytics.models.fastsam.utils import bbox_iou -from ultralytics.models.yolo.detect.predict import DetectionPredictor -from ultralytics.utils import DEFAULT_CFG, ops +from ultralytics.models.yolo.segment import SegmentationPredictor +from ultralytics.utils.metrics import box_iou + +from .utils import adjust_bboxes_to_image_border -class FastSAMPredictor(DetectionPredictor): +class FastSAMPredictor(SegmentationPredictor): """ FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in Ultralytics YOLO framework. - This class extends the DetectionPredictor, customizing the prediction pipeline specifically for fast SAM. - It adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing - for single-class segmentation. - - Attributes: - cfg (dict): Configuration parameters for prediction. - overrides (dict, optional): Optional parameter overrides for custom behavior. - _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction. + This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It + adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing for single- + class segmentation. """ - def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): - """ - Initializes the FastSAMPredictor class, inheriting from DetectionPredictor and setting the task to 'segment'. - - Args: - cfg (dict): Configuration parameters for prediction. - overrides (dict, optional): Optional parameter overrides for custom behavior. - _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction. - """ - super().__init__(cfg, overrides, _callbacks) - self.args.task = "segment" - def postprocess(self, preds, img, orig_imgs): - """ - Perform post-processing steps on predictions, including non-max suppression and scaling boxes to original image - size, and returns the final results. - - Args: - preds (list): The raw output predictions from the model. - img (torch.Tensor): The processed image tensor. - orig_imgs (list | torch.Tensor): The original image or list of images. - - Returns: - (list): A list of Results objects, each containing processed boxes, masks, and other metadata. - """ - p = ops.non_max_suppression( - preds[0], - self.args.conf, - self.args.iou, - agnostic=self.args.agnostic_nms, - max_det=self.args.max_det, - nc=1, # set to 1 class since SAM has no class predictions - classes=self.args.classes, - ) - full_box = torch.zeros(p[0].shape[1], device=p[0].device) - full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0 - full_box = full_box.view(1, -1) - critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:]) - if critical_iou_index.numel() != 0: - full_box[0][4] = p[0][critical_iou_index][:, 4] - full_box[0][6:] = p[0][critical_iou_index][:, 6:] - p[0][critical_iou_index] = full_box - - if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list - orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) - - results = [] - proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported - for i, (pred, orig_img, img_path) in enumerate(zip(p, orig_imgs, self.batch[0])): - if not len(pred): # save empty boxes - masks = None - elif self.args.retina_masks: - pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) - masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC - else: - masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC - pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) - results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)) + """Applies box postprocess for FastSAM predictions.""" + results = super().postprocess(preds, img, orig_imgs) + for result in results: + full_box = torch.tensor( + [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32 + ) + boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape) + idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten() + if idx.numel() != 0: + result.boxes.xyxy[idx] = full_box return results diff --git a/ultralytics/models/fastsam/utils.py b/ultralytics/models/fastsam/utils.py index 480e903942..5427083e35 100644 --- a/ultralytics/models/fastsam/utils.py +++ b/ultralytics/models/fastsam/utils.py @@ -1,7 +1,5 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -import torch - def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20): """ @@ -25,43 +23,3 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20): boxes[boxes[:, 2] > w - threshold, 2] = w # x2 boxes[boxes[:, 3] > h - threshold, 3] = h # y2 return boxes - - -def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False): - """ - Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes. - - Args: - box1 (torch.Tensor): (4, ) - boxes (torch.Tensor): (n, 4) - iou_thres (float): IoU threshold - image_shape (tuple): (height, width) - raw_output (bool): If True, return the raw IoU values instead of the indices - - Returns: - high_iou_indices (torch.Tensor): Indices of boxes with IoU > thres - """ - boxes = adjust_bboxes_to_image_border(boxes, image_shape) - # Obtain coordinates for intersections - x1 = torch.max(box1[0], boxes[:, 0]) - y1 = torch.max(box1[1], boxes[:, 1]) - x2 = torch.min(box1[2], boxes[:, 2]) - y2 = torch.min(box1[3], boxes[:, 3]) - - # Compute the area of intersection - intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0) - - # Compute the area of both individual boxes - box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) - box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) - - # Compute the area of union - union = box1_area + box2_area - intersection - - # Compute the IoU - iou = intersection / union # Should be shape (n, ) - if raw_output: - return 0 if iou.numel() == 0 else iou - - # return indices of boxes with IoU > thres - return torch.nonzero(iou > iou_thres).flatten()