Add SAM Predictor `remove_small_regions` test (#4576)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/4579/head
Laughing 1 year ago committed by GitHub
parent b4dca690d4
commit e9f596430f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      tests/test_cli.py
  2. 13
      ultralytics/models/sam/predict.py

@ -75,6 +75,7 @@ def test_fastsam(task='segment', model=WEIGHTS_DIR / 'FastSAM-s.pt', data='coco8
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt
from ultralytics.models.sam import Predictor
# Create a FastSAM model
sam_model = FastSAM(model) # or FastSAM-x.pt
@ -82,6 +83,9 @@ def test_fastsam(task='segment', model=WEIGHTS_DIR / 'FastSAM-s.pt', data='coco8
# Run inference on an image
everything_results = sam_model(source, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
# Remove small regions
new_masks, _ = Predictor.remove_small_regions(everything_results[0].masks.data, min_area=20)
# Everything prompt
prompt_process = FastSAMPrompt(source, everything_results, device='cpu')
ann = prompt_process.everything_prompt()

@ -374,6 +374,10 @@ class Predictor(BasePredictor):
masks (torch.Tensor): Masks, (N, H, W).
min_area (int): Minimum area threshold.
nms_thresh (float): NMS threshold.
Returns:
new_masks (torch.Tensor): New Masks, (N, H, W).
keep (List[int]): The indices of the new masks, which can be used to filter
the corresponding boxes.
"""
if len(masks) == 0:
return masks
@ -382,7 +386,7 @@ class Predictor(BasePredictor):
new_masks = []
scores = []
for mask in masks:
mask = mask.cpu().numpy()
mask = mask.cpu().numpy().astype(np.uint8)
mask, changed = remove_small_regions(mask, min_area, mode='holes')
unchanged = not changed
mask, changed = remove_small_regions(mask, min_area, mode='islands')
@ -402,9 +406,4 @@ class Predictor(BasePredictor):
nms_thresh,
)
# Only recalculate masks for masks that have changed
for i in keep:
if scores[i] == 0.0:
masks[i] = new_masks[i]
return masks[keep]
return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep

Loading…
Cancel
Save