|
|
|
@ -18,7 +18,9 @@ from .build import build_sam |
|
|
|
|
|
|
|
|
|
class Predictor(BasePredictor): |
|
|
|
|
|
|
|
|
|
def __init__(self, cfg=DEFAULT_CFG, overrides={}, _callbacks=None): |
|
|
|
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): |
|
|
|
|
if overrides is None: |
|
|
|
|
overrides = {} |
|
|
|
|
overrides.update(dict(task='segment', mode='predict', imgsz=1024)) |
|
|
|
|
super().__init__(cfg, overrides, _callbacks) |
|
|
|
|
# SAM needs retina_masks=True, or the results would be a mess. |
|
|
|
@ -90,7 +92,7 @@ class Predictor(BasePredictor): |
|
|
|
|
of masks and H=W=256. These low resolution logits can be passed to |
|
|
|
|
a subsequent iteration as mask input. |
|
|
|
|
""" |
|
|
|
|
if all([i is None for i in [bboxes, points, masks]]): |
|
|
|
|
if all(i is None for i in [bboxes, points, masks]): |
|
|
|
|
return self.generate(im, *args, **kwargs) |
|
|
|
|
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output) |
|
|
|
|
|
|
|
|
@ -284,7 +286,7 @@ class Predictor(BasePredictor): |
|
|
|
|
|
|
|
|
|
return pred_masks, pred_scores, pred_bboxes |
|
|
|
|
|
|
|
|
|
def setup_model(self, model): |
|
|
|
|
def setup_model(self, model, verbose=True): |
|
|
|
|
"""Set up YOLO model with specified thresholds and device.""" |
|
|
|
|
device = select_device(self.args.device) |
|
|
|
|
if model is None: |
|
|
|
@ -306,7 +308,7 @@ class Predictor(BasePredictor): |
|
|
|
|
# (N, 1, H, W), (N, 1) |
|
|
|
|
pred_masks, pred_scores = preds[:2] |
|
|
|
|
pred_bboxes = preds[2] if self.segment_all else None |
|
|
|
|
names = dict(enumerate([str(i) for i in range(len(pred_masks))])) |
|
|
|
|
names = dict(enumerate(str(i) for i in range(len(pred_masks)))) |
|
|
|
|
results = [] |
|
|
|
|
for i, masks in enumerate([pred_masks]): |
|
|
|
|
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs |
|
|
|
|