|
|
|
@ -374,6 +374,10 @@ class Predictor(BasePredictor): |
|
|
|
|
masks (torch.Tensor): Masks, (N, H, W). |
|
|
|
|
min_area (int): Minimum area threshold. |
|
|
|
|
nms_thresh (float): NMS threshold. |
|
|
|
|
Returns: |
|
|
|
|
new_masks (torch.Tensor): New Masks, (N, H, W). |
|
|
|
|
keep (List[int]): The indices of the new masks, which can be used to filter |
|
|
|
|
the corresponding boxes. |
|
|
|
|
""" |
|
|
|
|
if len(masks) == 0: |
|
|
|
|
return masks |
|
|
|
@ -382,7 +386,7 @@ class Predictor(BasePredictor): |
|
|
|
|
new_masks = [] |
|
|
|
|
scores = [] |
|
|
|
|
for mask in masks: |
|
|
|
|
mask = mask.cpu().numpy() |
|
|
|
|
mask = mask.cpu().numpy().astype(np.uint8) |
|
|
|
|
mask, changed = remove_small_regions(mask, min_area, mode='holes') |
|
|
|
|
unchanged = not changed |
|
|
|
|
mask, changed = remove_small_regions(mask, min_area, mode='islands') |
|
|
|
@ -402,9 +406,4 @@ class Predictor(BasePredictor): |
|
|
|
|
nms_thresh, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# Only recalculate masks for masks that have changed |
|
|
|
|
for i in keep: |
|
|
|
|
if scores[i] == 0.0: |
|
|
|
|
masks[i] = new_masks[i] |
|
|
|
|
|
|
|
|
|
return masks[keep] |
|
|
|
|
return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep |
|
|
|
|