@@ -95,7 +95,7 @@ class Predictor(BasePredictor):
         """
         if overrides is None:
             overrides = {}
-        overrides.update(dict(task="segment", mode="predict", imgsz=1024))
+        overrides.update(dict(task="segment", mode="predict"))
         super().__init__(cfg, overrides, _callbacks)
         self.args.retina_masks = True
         self.im = None
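
Dropping the hard-coded `imgsz=1024` override means the inference size is no longer pinned at construction and can be chosen per run. A minimal usage sketch, assuming the standard `ultralytics` predict API; the checkpoint and image names are assumptions, not part of this diff:

```python
from ultralytics import SAM

# imgsz is no longer forced to 1024 by the constructor overrides;
# a smaller square size can be requested at predict time.
model = SAM("sam_b.pt")  # assumed checkpoint name
results = model("ultralytics/assets/bus.jpg", imgsz=512)
```
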
@@ -455,8 +455,11 @@ class Predictor(BasePredictor):
             cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
             pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
 
-            masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
-            masks = masks > self.model.mask_threshold  # to bool
+            if len(masks) == 0:
+                masks = None
+            else:
+                masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
+                masks = masks > self.model.mask_threshold  # to bool
             results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
         # Reset segment-all mode.
         self.segment_all = False
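
The new guard returns `masks=None` instead of passing an empty tensor through `ops.scale_masks`, which would otherwise try to upsample a zero-length batch when a prompt produces no predictions. A standalone sketch of the same pattern, using `torch.nn.functional.interpolate` as a stand-in for `ops.scale_masks`:

```python
import torch
import torch.nn.functional as F

def scale_or_none(masks, shape, threshold=0.0):
    """Upscale masks to `shape`, returning None for empty predictions."""
    if len(masks) == 0:  # no masks predicted: skip scaling entirely
        return None
    masks = F.interpolate(masks[None].float(), size=shape, mode="bilinear", align_corners=False)[0]
    return masks > threshold  # binarize, mirroring `> self.model.mask_threshold`

print(scale_or_none(torch.empty(0, 256, 256), (640, 640)))  # None
```
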
@@ -522,6 +525,10 @@ class Predictor(BasePredictor):
 
     def get_im_features(self, im):
         """Extracts image features using the SAM model's image encoder for subsequent mask prediction."""
+        assert (
+            isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1]
+        ), f"SAM models only support square image size, but got {self.imgsz}."
+        self.model.set_imgsz(self.imgsz)
         return self.model.image_encoder(im)
 
     def set_prompts(self, prompts):
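
With `set_imgsz` called from `get_im_features`, the encoder is re-configured from `self.imgsz` before every feature extraction, so a non-default (square) size can be supplied through overrides. A sketch following the documented `Predictor` usage; the checkpoint name and box coordinates are assumptions:

```python
from ultralytics.models.sam import Predictor

# imgsz must be square, or the new assert in get_im_features fires.
overrides = dict(task="segment", mode="predict", imgsz=512, model="sam_b.pt")
predictor = Predictor(overrides=overrides)
predictor.set_image("ultralytics/assets/zidane.jpg")
results = predictor(bboxes=[439, 437, 524, 709])
```
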
@@ -761,6 +768,12 @@ class SAM2Predictor(Predictor):
 
     def get_im_features(self, im):
         """Extracts image features from the SAM image encoder for subsequent processing."""
+        assert (
+            isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1]
+        ), f"SAM 2 models only support square image size, but got {self.imgsz}."
+        self.model.set_imgsz(self.imgsz)
+        self._bb_feat_sizes = [[x // (4 * i) for x in self.imgsz] for i in [1, 2, 4]]
+
         backbone_out = self.model.forward_image(im)
         _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
         if self.model.directly_add_no_mem_embed:
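
`_bb_feat_sizes` now derives the three backbone feature-map sizes from `self.imgsz` rather than assuming a 1024-square input; the divisors `4 * i` for `i in [1, 2, 4]` correspond to strides 4, 8, and 16. The arithmetic, checked by hand:

```python
# Backbone feature sizes at strides 4, 8 and 16 relative to the input.
# The previously fixed imgsz of 1024 gives [256, 256], [128, 128], [64, 64];
# a 512-square input now yields proportionally smaller maps.
imgsz = (512, 512)
bb_feat_sizes = [[x // (4 * i) for x in imgsz] for i in [1, 2, 4]]
print(bb_feat_sizes)  # [[128, 128], [64, 64], [32, 32]]
```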