|
|
|
@ -450,16 +450,18 @@ class Predictor(BasePredictor): |
|
|
|
|
|
|
|
|
|
results = [] |
|
|
|
|
for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]): |
|
|
|
|
if pred_bboxes is not None: |
|
|
|
|
pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False) |
|
|
|
|
cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device) |
|
|
|
|
pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1) |
|
|
|
|
|
|
|
|
|
if len(masks) == 0: |
|
|
|
|
masks = None |
|
|
|
|
else: |
|
|
|
|
masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0] |
|
|
|
|
masks = masks > self.model.mask_threshold # to bool |
|
|
|
|
if pred_bboxes is not None: |
|
|
|
|
pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False) |
|
|
|
|
else: |
|
|
|
|
pred_bboxes = batched_mask_to_box(masks) |
|
|
|
|
# NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency. |
|
|
|
|
cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device) |
|
|
|
|
pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1) |
|
|
|
|
results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes)) |
|
|
|
|
# Reset segment-all mode. |
|
|
|
|
self.segment_all = False |
|
|
|
|