diff --git a/tests/test_cli.py b/tests/test_cli.py index 51d2c123d..b4a09fcf3 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -75,6 +75,7 @@ def test_fastsam(task='segment', model=WEIGHTS_DIR / 'FastSAM-s.pt', data='coco8 from ultralytics import FastSAM from ultralytics.models.fastsam import FastSAMPrompt + from ultralytics.models.sam import Predictor # Create a FastSAM model sam_model = FastSAM(model) # or FastSAM-x.pt @@ -82,6 +83,9 @@ def test_fastsam(task='segment', model=WEIGHTS_DIR / 'FastSAM-s.pt', data='coco8 # Run inference on an image everything_results = sam_model(source, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9) + # Remove small regions + new_masks, _ = Predictor.remove_small_regions(everything_results[0].masks.data, min_area=20) + # Everything prompt prompt_process = FastSAMPrompt(source, everything_results, device='cpu') ann = prompt_process.everything_prompt() diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py index 221d3ce4e..5f0a97894 100644 --- a/ultralytics/models/sam/predict.py +++ b/ultralytics/models/sam/predict.py @@ -374,6 +374,10 @@ class Predictor(BasePredictor): masks (torch.Tensor): Masks, (N, H, W). min_area (int): Minimum area threshold. nms_thresh (float): NMS threshold. + Returns: + new_masks (torch.Tensor): New Masks, (N, H, W). + keep (List[int]): The indices of the new masks, which can be used to filter + the corresponding boxes. """ if len(masks) == 0: return masks @@ -382,7 +386,7 @@ class Predictor(BasePredictor): new_masks = [] scores = [] for mask in masks: - mask = mask.cpu().numpy() + mask = mask.cpu().numpy().astype(np.uint8) mask, changed = remove_small_regions(mask, min_area, mode='holes') unchanged = not changed mask, changed = remove_small_regions(mask, min_area, mode='islands') @@ -402,9 +406,4 @@ class Predictor(BasePredictor): nms_thresh, ) - # Only recalculate masks for masks that have changed - for i in keep: - if scores[i] == 0.0: - masks[i] = new_masks[i] - - return masks[keep] + return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep