diff --git a/ultralytics/models/fastsam/predict.py b/ultralytics/models/fastsam/predict.py
index 5c9cb21c0..36222c862 100644
--- a/ultralytics/models/fastsam/predict.py
+++ b/ultralytics/models/fastsam/predict.py
@@ -22,7 +22,7 @@ class FastSAMPredictor(DetectionPredictor):
                                     max_det=self.args.max_det,
                                     nc=len(self.model.names),
                                     classes=self.args.classes)
-        full_box = torch.zeros_like(p[0][0])
+        full_box = torch.zeros(p[0].shape[1], device=p[0].device)
         full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
         full_box = full_box.view(1, -1)
         critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
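Note on the `full_box` change: sizing the row from `p[0].shape[1]` instead of indexing `p[0][0]` keeps postprocess from crashing when NMS returns zero boxes, and pinning the new row to `p[0].device` avoids a CPU/GPU mismatch in the `bbox_iou` call that follows. A minimal sketch; the 38-column width (4 box coords, 1 conf, 1 cls, 32 mask coefficients) is illustrative, not from a real model run:

```python
import torch

p0 = torch.rand(0, 38)   # an empty NMS output: zero boxes survive
img_h, img_w = 480, 640  # img.shape[2], img.shape[3] in the diff

# torch.zeros_like(p0[0]) would raise IndexError on zero boxes;
# sizing the row from shape[1] (and matching p0's device) does not.
full_box = torch.zeros(p0.shape[1], device=p0.device)
full_box[2], full_box[3], full_box[4], full_box[6:] = img_w, img_h, 1.0, 1.0  # full-image box, conf 1.0
full_box = full_box.view(1, -1)
print(full_box.shape)  # torch.Size([1, 38])
```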
diff --git a/ultralytics/models/fastsam/prompt.py b/ultralytics/models/fastsam/prompt.py
index 3bdc9448e..79266e1e5 100644
--- a/ultralytics/models/fastsam/prompt.py
+++ b/ultralytics/models/fastsam/prompt.py
@@ -8,18 +8,17 @@ import matplotlib.pyplot as plt
 import numpy as np
 import torch
 from PIL import Image
+from tqdm import tqdm
 
-from ultralytics.utils import LOGGER
+from ultralytics.utils import TQDM_BAR_FORMAT
 
 
 class FastSAMPrompt:
 
-    def __init__(self, img_path, results, device='cuda') -> None:
-        # self.img_path = img_path
+    def __init__(self, source, results, device='cuda') -> None:
         self.device = device
         self.results = results
-        self.img_path = str(img_path)
-        self.ori_img = cv2.imread(self.img_path)
+        self.source = source
 
         # Import and assign clip
         try:
@@ -48,7 +47,7 @@ class FastSAMPrompt:
     @staticmethod
     def _format_results(result, filter=0):
         annotations = []
-        n = len(result.masks.data)
+        n = len(result.masks.data) if result.masks is not None else 0
        for i in range(n):
            mask = result.masks.data[i] == 1.0
            if torch.sum(mask) >= filter:
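With the constructor now keeping the raw `source` instead of a single `img_path` (per-image data comes from each `Results` object rather than a second `cv2.imread`), a directory can be passed straight through to the prompt class. A hedged usage sketch; the weights file and paths are placeholders:

```python
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt

source = 'path/to/images'        # a single image or, now, a whole directory
model = FastSAM('FastSAM-s.pt')  # placeholder weights file
everything_results = model(source, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
prompt_process = FastSAMPrompt(source, everything_results, device='cpu')
```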
@@ -86,69 +85,79 @@ class FastSAMPrompt:
              mask_random_color=True,
              better_quality=True,
              retina=False,
-             with_countouers=True):
-        if isinstance(annotations[0], dict):
-            annotations = [annotation['segmentation'] for annotation in annotations]
-        if isinstance(annotations, torch.Tensor):
-            annotations = annotations.cpu().numpy()
-        result_name = os.path.basename(self.img_path)
-        image = self.ori_img
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        original_h = image.shape[0]
-        original_w = image.shape[1]
-        # for macOS only
-        # plt.switch_backend('TkAgg')
-        fig = plt.figure(figsize=(original_w / 100, original_h / 100))
-        # Add subplot with no margin.
-        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
-        plt.margins(0, 0)
-        plt.gca().xaxis.set_major_locator(plt.NullLocator())
-        plt.gca().yaxis.set_major_locator(plt.NullLocator())
-
-        plt.imshow(image)
-        if better_quality:
-            for i, mask in enumerate(annotations):
-                mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
-                annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
-        self.fast_show_mask(
-            annotations,
-            plt.gca(),
-            random_color=mask_random_color,
-            bbox=bbox,
-            points=points,
-            pointlabel=point_label,
-            retinamask=retina,
-            target_height=original_h,
-            target_width=original_w,
-        )
-
-        if with_countouers:
-            contour_all = []
-            temp = np.zeros((original_h, original_w, 1))
-            for i, mask in enumerate(annotations):
-                if isinstance(mask, dict):
-                    mask = mask['segmentation']
-                annotation = mask.astype(np.uint8)
-                if not retina:
-                    annotation = cv2.resize(
-                        annotation,
-                        (original_w, original_h),
-                        interpolation=cv2.INTER_NEAREST,
-                    )
-                contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
-                contour_all.extend(iter(contours))
-            cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
-            color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
-            contour_mask = temp / 255 * color.reshape(1, 1, -1)
-            plt.imshow(contour_mask)
-
-        save_path = Path(output) / result_name
-        save_path.parent.mkdir(exist_ok=True, parents=True)
-        plt.axis('off')
-        fig.savefig(save_path)
-        LOGGER.info(f'Saved to {save_path.absolute()}')
-
-    # CPU post process
+             withContours=True):
+        n = len(annotations)
+        pbar = tqdm(annotations, total=n, bar_format=TQDM_BAR_FORMAT)
+        for ann in pbar:
+            result_name = os.path.basename(ann.path)
+            image = ann.orig_img
+            original_h, original_w = ann.orig_shape
+            # for macOS only
+            # plt.switch_backend('TkAgg')
+            plt.figure(figsize=(original_w / 100, original_h / 100))
+            # Add subplot with no margin.
+            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+            plt.margins(0, 0)
+            plt.gca().xaxis.set_major_locator(plt.NullLocator())
+            plt.gca().yaxis.set_major_locator(plt.NullLocator())
+            plt.imshow(image)
+
+            if ann.masks is not None:
+                masks = ann.masks.data
+                if better_quality:
+                    if isinstance(masks[0], torch.Tensor):
+                        masks = np.array(masks.cpu())
+                    for i, mask in enumerate(masks):
+                        mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
+                        masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
+
+                self.fast_show_mask(
+                    masks,
+                    plt.gca(),
+                    random_color=mask_random_color,
+                    bbox=bbox,
+                    points=points,
+                    pointlabel=point_label,
+                    retinamask=retina,
+                    target_height=original_h,
+                    target_width=original_w,
+                )
+
+                if withContours:
+                    contour_all = []
+                    temp = np.zeros((original_h, original_w, 1))
+                    for i, mask in enumerate(masks):
+                        mask = mask.astype(np.uint8)
+                        if not retina:
+                            mask = cv2.resize(
+                                mask,
+                                (original_w, original_h),
+                                interpolation=cv2.INTER_NEAREST,
+                            )
+                        contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+                        contour_all.extend(iter(contours))
+                    cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
+                    color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
+                    contour_mask = temp / 255 * color.reshape(1, 1, -1)
+                    plt.imshow(contour_mask)
+
+            plt.axis('off')
+            fig = plt.gcf()
+
+            try:
+                buf = fig.canvas.tostring_rgb()
+            except AttributeError:
+                fig.canvas.draw()
+                buf = fig.canvas.tostring_rgb()
+            cols, rows = fig.canvas.get_width_height()
+            img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
+
+            save_path = Path(output) / result_name
+            save_path.parent.mkdir(exist_ok=True, parents=True)
+            cv2.imwrite(str(save_path), img_array)
+            plt.close()
+            pbar.set_description('Saving {} to {}'.format(result_name, save_path))
+
     @staticmethod
     def fast_show_mask(
         annotation,
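`plot()` now iterates over the `Results` list itself (with a tqdm progress bar and one saved file per `ann.path`), so the caller hands it the predictor output directly. Continuing the sketch above:

```python
# Saves one rendered image per result under ./output/
prompt_process.plot(annotations=everything_results, output='./output/')
```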
@@ -215,8 +224,9 @@ class FastSAMPrompt:
         return probs[:, 0].softmax(dim=0)
 
     def _crop_image(self, format_results):
-
-        image = Image.fromarray(cv2.cvtColor(self.ori_img, cv2.COLOR_BGR2RGB))
+        if os.path.isdir(self.source):
+            raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
+        image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
         ori_w, ori_h = image.size
         annotations = format_results
         mask_h, mask_w = annotations[0]['segmentation'].shape
@@ -237,65 +247,71 @@ class FastSAMPrompt:
         return cropped_boxes, cropped_images, not_crop, filter_id, annotations
 
     def box_prompt(self, bbox):
-
-        assert (bbox[2] != 0 and bbox[3] != 0)
-        masks = self.results[0].masks.data
-        target_height = self.ori_img.shape[0]
-        target_width = self.ori_img.shape[1]
-        h = masks.shape[1]
-        w = masks.shape[2]
-        if h != target_height or w != target_width:
-            bbox = [
-                int(bbox[0] * w / target_width),
-                int(bbox[1] * h / target_height),
-                int(bbox[2] * w / target_width),
-                int(bbox[3] * h / target_height), ]
-        bbox[0] = max(round(bbox[0]), 0)
-        bbox[1] = max(round(bbox[1]), 0)
-        bbox[2] = min(round(bbox[2]), w)
-        bbox[3] = min(round(bbox[3]), h)
-
-        # IoUs = torch.zeros(len(masks), dtype=torch.float32)
-        bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
-
-        masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
-        orig_masks_area = torch.sum(masks, dim=(1, 2))
-
-        union = bbox_area + orig_masks_area - masks_area
-        IoUs = masks_area / union
-        max_iou_index = torch.argmax(IoUs)
-
-        return np.array([masks[max_iou_index].cpu().numpy()])
+        if self.results[0].masks is not None:
+            assert (bbox[2] != 0 and bbox[3] != 0)
+            if os.path.isdir(self.source):
+                raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
+            masks = self.results[0].masks.data
+            target_height, target_width = self.results[0].orig_shape
+            h = masks.shape[1]
+            w = masks.shape[2]
+            if h != target_height or w != target_width:
+                bbox = [
+                    int(bbox[0] * w / target_width),
+                    int(bbox[1] * h / target_height),
+                    int(bbox[2] * w / target_width),
+                    int(bbox[3] * h / target_height), ]
+            bbox[0] = max(round(bbox[0]), 0)
+            bbox[1] = max(round(bbox[1]), 0)
+            bbox[2] = min(round(bbox[2]), w)
+            bbox[3] = min(round(bbox[3]), h)
+
+            # IoUs = torch.zeros(len(masks), dtype=torch.float32)
+            bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
+
+            masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
+            orig_masks_area = torch.sum(masks, dim=(1, 2))
+
+            union = bbox_area + orig_masks_area - masks_area
+            IoUs = masks_area / union
+            max_iou_index = torch.argmax(IoUs)
+
+            self.results[0].masks.data = torch.tensor(np.array([masks[max_iou_index].cpu().numpy()]))
+        return self.results
 
     def point_prompt(self, points, pointlabel):  # numpy processing
-
-        masks = self._format_results(self.results[0], 0)
-        target_height = self.ori_img.shape[0]
-        target_width = self.ori_img.shape[1]
-        h = masks[0]['segmentation'].shape[0]
-        w = masks[0]['segmentation'].shape[1]
-        if h != target_height or w != target_width:
-            points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
-        onemask = np.zeros((h, w))
-        for i, annotation in enumerate(masks):
-            mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation
-            for i, point in enumerate(points):
-                if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
-                    onemask += mask
-                if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
-                    onemask -= mask
-        onemask = onemask >= 1
-        return np.array([onemask])
+        if self.results[0].masks is not None:
+            if os.path.isdir(self.source):
+                raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
+            masks = self._format_results(self.results[0], 0)
+            target_height, target_width = self.results[0].orig_shape
+            h = masks[0]['segmentation'].shape[0]
+            w = masks[0]['segmentation'].shape[1]
+            if h != target_height or w != target_width:
+                points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
+            onemask = np.zeros((h, w))
+            for i, annotation in enumerate(masks):
+                mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation
+                for i, point in enumerate(points):
+                    if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
+                        onemask += mask
+                    if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
+                        onemask -= mask
+            onemask = onemask >= 1
+            self.results[0].masks.data = torch.tensor(np.array([onemask]))
+        return self.results
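Because `box_prompt` and `point_prompt` now write the selected mask back into `results[0].masks.data` and return the `Results` list (instead of a bare numpy array), their output plugs directly into `plot()`. Coordinates below are illustrative:

```python
ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300])  # xyxy pixels
prompt_process.plot(annotations=ann, output='./output/')

ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1])  # 1 = foreground, 0 = background
prompt_process.plot(annotations=ann, output='./output/')
```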
 
     def text_prompt(self, text):
-        format_results = self._format_results(self.results[0], 0)
-        cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
-        clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
-        scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
-        max_idx = scores.argsort()
-        max_idx = max_idx[-1]
-        max_idx += sum(np.array(filter_id) <= int(max_idx))
-        return np.array([annotations[max_idx]['segmentation']])
+        if self.results[0].masks is not None:
+            format_results = self._format_results(self.results[0], 0)
+            cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
+            clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
+            scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
+            max_idx = scores.argsort()
+            max_idx = max_idx[-1]
+            max_idx += sum(np.array(filter_id) <= int(max_idx))
+            self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]['segmentation']]))
+        return self.results
 
     def everything_prompt(self):
-        return self.results[0].masks.data
+        return self.results
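The text and everything prompts follow the same contract: `text_prompt` keeps only `annotations[max_idx]`, the single best CLIP-scored mask (matching what the old `return` produced), and `everything_prompt` now returns the `Results` list rather than raw mask data. Assuming the CLIP dependency resolved in `__init__`, a final sketch:

```python
ann = prompt_process.text_prompt(text='a photo of a dog')
prompt_process.plot(annotations=ann, output='./output/')

ann = prompt_process.everything_prompt()  # Results list, not results[0].masks.data
prompt_process.plot(annotations=ann, output='./output/')
```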