`ultralytics 8.0.127` add FastSAM model (#3390)
Co-authored-by: dingwenchao <12962189468@163.com> Co-authored-by: 丁文超 <dingwenchao@dingwenchaodeMacBook-Pro.local> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>main
parent
91905b4b0b
commit
400f3f72a1
8 changed files with 942 additions and 6 deletions
@ -1,13 +1,14 @@ |
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
||||
|
||||
__version__ = '8.0.126' |
||||
__version__ = '8.0.127' |
||||
|
||||
from ultralytics.hub import start |
||||
from ultralytics.vit.rtdetr import RTDETR |
||||
from ultralytics.vit.sam import SAM |
||||
from ultralytics.yolo.engine.model import YOLO |
||||
from ultralytics.yolo.fastsam import FastSAM |
||||
from ultralytics.yolo.nas import NAS |
||||
from ultralytics.yolo.utils.checks import check_yolo as checks |
||||
from ultralytics.yolo.utils.downloads import download |
||||
|
||||
__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'RTDETR', 'checks', 'start', 'download' # allow simpler import |
||||
__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'RTDETR', 'checks', 'start', 'download', 'FastSAM' # allow simpler import |
||||
|
@ -0,0 +1,8 @@ |
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
||||
|
||||
from .model import FastSAM |
||||
from .predict import FastSAMPredictor |
||||
from .prompt import FastSAMPrompt |
||||
from .val import FastSAMValidator |
||||
|
||||
__all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMValidator' |
@ -0,0 +1,51 @@ |
||||
import torch |
||||
|
||||
from ultralytics.yolo.engine.results import Results |
||||
from ultralytics.yolo.fastsam.utils import bbox_iou |
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ops |
||||
from ultralytics.yolo.v8.detect.predict import DetectionPredictor |
||||
|
||||
|
||||
class FastSAMPredictor(DetectionPredictor): |
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): |
||||
super().__init__(cfg, overrides, _callbacks) |
||||
self.args.task = 'segment' |
||||
|
||||
def postprocess(self, preds, img, orig_imgs): |
||||
"""TODO: filter by classes.""" |
||||
p = ops.non_max_suppression(preds[0], |
||||
self.args.conf, |
||||
self.args.iou, |
||||
agnostic=self.args.agnostic_nms, |
||||
max_det=self.args.max_det, |
||||
nc=len(self.model.names), |
||||
classes=self.args.classes) |
||||
full_box = torch.zeros_like(p[0][0]) |
||||
full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0 |
||||
full_box = full_box.view(1, -1) |
||||
critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:]) |
||||
if critical_iou_index.numel() != 0: |
||||
full_box[0][4] = p[0][critical_iou_index][:, 4] |
||||
full_box[0][6:] = p[0][critical_iou_index][:, 6:] |
||||
p[0][critical_iou_index] = full_box |
||||
results = [] |
||||
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported |
||||
for i, pred in enumerate(p): |
||||
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs |
||||
path = self.batch[0] |
||||
img_path = path[i] if isinstance(path, list) else path |
||||
if not len(pred): # save empty boxes |
||||
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])) |
||||
continue |
||||
if self.args.retina_masks: |
||||
if not isinstance(orig_imgs, torch.Tensor): |
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) |
||||
masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC |
||||
else: |
||||
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC |
||||
if not isinstance(orig_imgs, torch.Tensor): |
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) |
||||
results.append( |
||||
Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)) |
||||
return results |
@ -0,0 +1,406 @@ |
||||
import os |
||||
|
||||
import cv2 |
||||
import matplotlib.pyplot as plt |
||||
import numpy as np |
||||
import torch |
||||
from PIL import Image |
||||
|
||||
try: |
||||
import clip # for linear_assignment |
||||
|
||||
except (ImportError, AssertionError, AttributeError): |
||||
from ultralytics.yolo.utils.checks import check_requirements |
||||
|
||||
check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source |
||||
import clip |
||||
|
||||
|
||||
class FastSAMPrompt: |
||||
|
||||
def __init__(self, img_path, results, device='cuda') -> None: |
||||
# self.img_path = img_path |
||||
self.device = device |
||||
self.results = results |
||||
self.img_path = img_path |
||||
self.ori_img = cv2.imread(img_path) |
||||
|
||||
def _segment_image(self, image, bbox): |
||||
image_array = np.array(image) |
||||
segmented_image_array = np.zeros_like(image_array) |
||||
x1, y1, x2, y2 = bbox |
||||
segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2] |
||||
segmented_image = Image.fromarray(segmented_image_array) |
||||
black_image = Image.new('RGB', image.size, (255, 255, 255)) |
||||
# transparency_mask = np.zeros_like((), dtype=np.uint8) |
||||
transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8) |
||||
transparency_mask[y1:y2, x1:x2] = 255 |
||||
transparency_mask_image = Image.fromarray(transparency_mask, mode='L') |
||||
black_image.paste(segmented_image, mask=transparency_mask_image) |
||||
return black_image |
||||
|
||||
def _format_results(self, result, filter=0): |
||||
annotations = [] |
||||
n = len(result.masks.data) |
||||
for i in range(n): |
||||
annotation = {} |
||||
mask = result.masks.data[i] == 1.0 |
||||
|
||||
if torch.sum(mask) < filter: |
||||
continue |
||||
annotation['id'] = i |
||||
annotation['segmentation'] = mask.cpu().numpy() |
||||
annotation['bbox'] = result.boxes.data[i] |
||||
annotation['score'] = result.boxes.conf[i] |
||||
annotation['area'] = annotation['segmentation'].sum() |
||||
annotations.append(annotation) |
||||
return annotations |
||||
|
||||
def filter_masks(annotations): # filte the overlap mask |
||||
annotations.sort(key=lambda x: x['area'], reverse=True) |
||||
to_remove = set() |
||||
for i in range(0, len(annotations)): |
||||
a = annotations[i] |
||||
for j in range(i + 1, len(annotations)): |
||||
b = annotations[j] |
||||
if i != j and j not in to_remove: |
||||
# check if |
||||
if b['area'] < a['area']: |
||||
if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8: |
||||
to_remove.add(j) |
||||
|
||||
return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove |
||||
|
||||
def _get_bbox_from_mask(self, mask): |
||||
mask = mask.astype(np.uint8) |
||||
contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) |
||||
x1, y1, w, h = cv2.boundingRect(contours[0]) |
||||
x2, y2 = x1 + w, y1 + h |
||||
if len(contours) > 1: |
||||
for b in contours: |
||||
x_t, y_t, w_t, h_t = cv2.boundingRect(b) |
||||
# 将多个bbox合并成一个 |
||||
x1 = min(x1, x_t) |
||||
y1 = min(y1, y_t) |
||||
x2 = max(x2, x_t + w_t) |
||||
y2 = max(y2, y_t + h_t) |
||||
h = y2 - y1 |
||||
w = x2 - x1 |
||||
return [x1, y1, x2, y2] |
||||
|
||||
def plot(self, |
||||
annotations, |
||||
output, |
||||
bbox=None, |
||||
points=None, |
||||
point_label=None, |
||||
mask_random_color=True, |
||||
better_quality=True, |
||||
retina=False, |
||||
withContours=True): |
||||
if isinstance(annotations[0], dict): |
||||
annotations = [annotation['segmentation'] for annotation in annotations] |
||||
result_name = os.path.basename(self.img_path) |
||||
image = self.ori_img |
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) |
||||
original_h = image.shape[0] |
||||
original_w = image.shape[1] |
||||
# for MacOS only |
||||
# plt.switch_backend('TkAgg') |
||||
plt.figure(figsize=(original_w / 100, original_h / 100)) |
||||
# Add subplot with no margin. |
||||
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) |
||||
plt.margins(0, 0) |
||||
plt.gca().xaxis.set_major_locator(plt.NullLocator()) |
||||
plt.gca().yaxis.set_major_locator(plt.NullLocator()) |
||||
|
||||
plt.imshow(image) |
||||
if better_quality: |
||||
if isinstance(annotations[0], torch.Tensor): |
||||
annotations = np.array(annotations.cpu()) |
||||
for i, mask in enumerate(annotations): |
||||
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)) |
||||
annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)) |
||||
if self.device == 'cpu': |
||||
annotations = np.array(annotations) |
||||
self.fast_show_mask( |
||||
annotations, |
||||
plt.gca(), |
||||
random_color=mask_random_color, |
||||
bbox=bbox, |
||||
points=points, |
||||
pointlabel=point_label, |
||||
retinamask=retina, |
||||
target_height=original_h, |
||||
target_width=original_w, |
||||
) |
||||
else: |
||||
if isinstance(annotations[0], np.ndarray): |
||||
annotations = torch.from_numpy(annotations) |
||||
self.fast_show_mask_gpu( |
||||
annotations, |
||||
plt.gca(), |
||||
random_color=mask_random_color, |
||||
bbox=bbox, |
||||
points=points, |
||||
pointlabel=point_label, |
||||
retinamask=retina, |
||||
target_height=original_h, |
||||
target_width=original_w, |
||||
) |
||||
if isinstance(annotations, torch.Tensor): |
||||
annotations = annotations.cpu().numpy() |
||||
if withContours: |
||||
contour_all = [] |
||||
temp = np.zeros((original_h, original_w, 1)) |
||||
for i, mask in enumerate(annotations): |
||||
if type(mask) == dict: |
||||
mask = mask['segmentation'] |
||||
annotation = mask.astype(np.uint8) |
||||
if not retina: |
||||
annotation = cv2.resize( |
||||
annotation, |
||||
(original_w, original_h), |
||||
interpolation=cv2.INTER_NEAREST, |
||||
) |
||||
contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) |
||||
for contour in contours: |
||||
contour_all.append(contour) |
||||
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2) |
||||
color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8]) |
||||
contour_mask = temp / 255 * color.reshape(1, 1, -1) |
||||
plt.imshow(contour_mask) |
||||
|
||||
save_path = output |
||||
if not os.path.exists(save_path): |
||||
os.makedirs(save_path) |
||||
plt.axis('off') |
||||
fig = plt.gcf() |
||||
plt.draw() |
||||
|
||||
try: |
||||
buf = fig.canvas.tostring_rgb() |
||||
except AttributeError: |
||||
fig.canvas.draw() |
||||
buf = fig.canvas.tostring_rgb() |
||||
cols, rows = fig.canvas.get_width_height() |
||||
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3) |
||||
cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)) |
||||
|
||||
# CPU post process |
||||
def fast_show_mask( |
||||
self, |
||||
annotation, |
||||
ax, |
||||
random_color=False, |
||||
bbox=None, |
||||
points=None, |
||||
pointlabel=None, |
||||
retinamask=True, |
||||
target_height=960, |
||||
target_width=960, |
||||
): |
||||
msak_sum = annotation.shape[0] |
||||
height = annotation.shape[1] |
||||
weight = annotation.shape[2] |
||||
# 将annotation 按照面积 排序 |
||||
areas = np.sum(annotation, axis=(1, 2)) |
||||
sorted_indices = np.argsort(areas) |
||||
annotation = annotation[sorted_indices] |
||||
|
||||
index = (annotation != 0).argmax(axis=0) |
||||
if random_color: |
||||
color = np.random.random((msak_sum, 1, 1, 3)) |
||||
else: |
||||
color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255]) |
||||
transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6 |
||||
visual = np.concatenate([color, transparency], axis=-1) |
||||
mask_image = np.expand_dims(annotation, -1) * visual |
||||
|
||||
show = np.zeros((height, weight, 4)) |
||||
h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij') |
||||
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) |
||||
# 使用向量化索引更新show的值 |
||||
show[h_indices, w_indices, :] = mask_image[indices] |
||||
if bbox is not None: |
||||
x1, y1, x2, y2 = bbox |
||||
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1)) |
||||
# draw point |
||||
if points is not None: |
||||
plt.scatter( |
||||
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1], |
||||
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1], |
||||
s=20, |
||||
c='y', |
||||
) |
||||
plt.scatter( |
||||
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0], |
||||
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0], |
||||
s=20, |
||||
c='m', |
||||
) |
||||
|
||||
if not retinamask: |
||||
show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST) |
||||
ax.imshow(show) |
||||
|
||||
def fast_show_mask_gpu( |
||||
self, |
||||
annotation, |
||||
ax, |
||||
random_color=False, |
||||
bbox=None, |
||||
points=None, |
||||
pointlabel=None, |
||||
retinamask=True, |
||||
target_height=960, |
||||
target_width=960, |
||||
): |
||||
msak_sum = annotation.shape[0] |
||||
height = annotation.shape[1] |
||||
weight = annotation.shape[2] |
||||
areas = torch.sum(annotation, dim=(1, 2)) |
||||
sorted_indices = torch.argsort(areas, descending=False) |
||||
annotation = annotation[sorted_indices] |
||||
# 找每个位置第一个非零值下标 |
||||
index = (annotation != 0).to(torch.long).argmax(dim=0) |
||||
if random_color: |
||||
color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device) |
||||
else: |
||||
color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([ |
||||
30 / 255, 144 / 255, 255 / 255]).to(annotation.device) |
||||
transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6 |
||||
visual = torch.cat([color, transparency], dim=-1) |
||||
mask_image = torch.unsqueeze(annotation, -1) * visual |
||||
# 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式 |
||||
show = torch.zeros((height, weight, 4)).to(annotation.device) |
||||
h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij') |
||||
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) |
||||
# 使用向量化索引更新show的值 |
||||
show[h_indices, w_indices, :] = mask_image[indices] |
||||
show_cpu = show.cpu().numpy() |
||||
if bbox is not None: |
||||
x1, y1, x2, y2 = bbox |
||||
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1)) |
||||
# draw point |
||||
if points is not None: |
||||
plt.scatter( |
||||
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1], |
||||
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1], |
||||
s=20, |
||||
c='y', |
||||
) |
||||
plt.scatter( |
||||
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0], |
||||
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0], |
||||
s=20, |
||||
c='m', |
||||
) |
||||
if not retinamask: |
||||
show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST) |
||||
ax.imshow(show_cpu) |
||||
|
||||
# clip |
||||
@torch.no_grad() |
||||
def retrieve(self, model, preprocess, elements, search_text: str, device) -> int: |
||||
preprocessed_images = [preprocess(image).to(device) for image in elements] |
||||
tokenized_text = clip.tokenize([search_text]).to(device) |
||||
stacked_images = torch.stack(preprocessed_images) |
||||
image_features = model.encode_image(stacked_images) |
||||
text_features = model.encode_text(tokenized_text) |
||||
image_features /= image_features.norm(dim=-1, keepdim=True) |
||||
text_features /= text_features.norm(dim=-1, keepdim=True) |
||||
probs = 100.0 * image_features @ text_features.T |
||||
return probs[:, 0].softmax(dim=0) |
||||
|
||||
def _crop_image(self, format_results): |
||||
|
||||
image = Image.fromarray(cv2.cvtColor(self.ori_img, cv2.COLOR_BGR2RGB)) |
||||
ori_w, ori_h = image.size |
||||
annotations = format_results |
||||
mask_h, mask_w = annotations[0]['segmentation'].shape |
||||
if ori_w != mask_w or ori_h != mask_h: |
||||
image = image.resize((mask_w, mask_h)) |
||||
cropped_boxes = [] |
||||
cropped_images = [] |
||||
not_crop = [] |
||||
filter_id = [] |
||||
# annotations, _ = filter_masks(annotations) |
||||
# filter_id = list(_) |
||||
for _, mask in enumerate(annotations): |
||||
if np.sum(mask['segmentation']) <= 100: |
||||
filter_id.append(_) |
||||
continue |
||||
bbox = self._get_bbox_from_mask(mask['segmentation']) # mask 的 bbox |
||||
cropped_boxes.append(self._segment_image(image, bbox)) # 保存裁剪的图片 |
||||
# cropped_boxes.append(segment_image(image,mask["segmentation"])) |
||||
cropped_images.append(bbox) # 保存裁剪的图片的bbox |
||||
|
||||
return cropped_boxes, cropped_images, not_crop, filter_id, annotations |
||||
|
||||
def box_prompt(self, bbox): |
||||
|
||||
assert (bbox[2] != 0 and bbox[3] != 0) |
||||
masks = self.results[0].masks.data |
||||
target_height = self.ori_img.shape[0] |
||||
target_width = self.ori_img.shape[1] |
||||
h = masks.shape[1] |
||||
w = masks.shape[2] |
||||
if h != target_height or w != target_width: |
||||
bbox = [ |
||||
int(bbox[0] * w / target_width), |
||||
int(bbox[1] * h / target_height), |
||||
int(bbox[2] * w / target_width), |
||||
int(bbox[3] * h / target_height), ] |
||||
bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0 |
||||
bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0 |
||||
bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w |
||||
bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h |
||||
|
||||
# IoUs = torch.zeros(len(masks), dtype=torch.float32) |
||||
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) |
||||
|
||||
masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2)) |
||||
orig_masks_area = torch.sum(masks, dim=(1, 2)) |
||||
|
||||
union = bbox_area + orig_masks_area - masks_area |
||||
IoUs = masks_area / union |
||||
max_iou_index = torch.argmax(IoUs) |
||||
|
||||
return np.array([masks[max_iou_index].cpu().numpy()]) |
||||
|
||||
def point_prompt(self, points, pointlabel): # numpy 处理 |
||||
|
||||
masks = self._format_results(self.results[0], 0) |
||||
target_height = self.ori_img.shape[0] |
||||
target_width = self.ori_img.shape[1] |
||||
h = masks[0]['segmentation'].shape[0] |
||||
w = masks[0]['segmentation'].shape[1] |
||||
if h != target_height or w != target_width: |
||||
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points] |
||||
onemask = np.zeros((h, w)) |
||||
for i, annotation in enumerate(masks): |
||||
if type(annotation) == dict: |
||||
mask = annotation['segmentation'] |
||||
else: |
||||
mask = annotation |
||||
for i, point in enumerate(points): |
||||
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1: |
||||
onemask += mask |
||||
if mask[point[1], point[0]] == 1 and pointlabel[i] == 0: |
||||
onemask -= mask |
||||
onemask = onemask >= 1 |
||||
return np.array([onemask]) |
||||
|
||||
def text_prompt(self, text): |
||||
format_results = self._format_results(self.results[0], 0) |
||||
cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results) |
||||
clip_model, preprocess = clip.load('ViT-B/32', device=self.device) |
||||
scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device) |
||||
max_idx = scores.argsort() |
||||
max_idx = max_idx[-1] |
||||
max_idx += sum(np.array(filter_id) <= int(max_idx)) |
||||
return np.array([annotations[max_idx]['segmentation']]) |
||||
|
||||
def everything_prompt(self): |
||||
return self.results[0].masks.data |
@ -0,0 +1,63 @@ |
||||
import torch |
||||
|
||||
|
||||
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20): |
||||
'''Adjust bounding boxes to stick to image border if they are within a certain threshold. |
||||
Args: |
||||
boxes: (n, 4) |
||||
image_shape: (height, width) |
||||
threshold: pixel threshold |
||||
|
||||
Returns: |
||||
adjusted_boxes: adjusted bounding boxes |
||||
''' |
||||
|
||||
# Image dimensions |
||||
h, w = image_shape |
||||
|
||||
# Adjust boxes |
||||
boxes[:, 0] = torch.where(boxes[:, 0] < threshold, 0, boxes[:, 0]) # x1 |
||||
boxes[:, 1] = torch.where(boxes[:, 1] < threshold, 0, boxes[:, 1]) # y1 |
||||
boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, w, boxes[:, 2]) # x2 |
||||
boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, h, boxes[:, 3]) # y2 |
||||
|
||||
return boxes |
||||
|
||||
|
||||
def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False): |
||||
'''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes. |
||||
Args: |
||||
box1: (4, ) |
||||
boxes: (n, 4) |
||||
|
||||
Returns: |
||||
high_iou_indices: Indices of boxes with IoU > thres |
||||
''' |
||||
boxes = adjust_bboxes_to_image_border(boxes, image_shape) |
||||
# obtain coordinates for intersections |
||||
x1 = torch.max(box1[0], boxes[:, 0]) |
||||
y1 = torch.max(box1[1], boxes[:, 1]) |
||||
x2 = torch.min(box1[2], boxes[:, 2]) |
||||
y2 = torch.min(box1[3], boxes[:, 3]) |
||||
|
||||
# compute the area of intersection |
||||
intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0) |
||||
|
||||
# compute the area of both individual boxes |
||||
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) |
||||
box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) |
||||
|
||||
# compute the area of union |
||||
union = box1_area + box2_area - intersection |
||||
|
||||
# compute the IoU |
||||
iou = intersection / union # Should be shape (n, ) |
||||
if raw_output: |
||||
if iou.numel() == 0: |
||||
return 0 |
||||
return iou |
||||
|
||||
# get indices of boxes with IoU > thres |
||||
high_iou_indices = torch.nonzero(iou > iou_thres).flatten() |
||||
|
||||
return high_iou_indices |
@ -0,0 +1,244 @@ |
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
||||
|
||||
from multiprocessing.pool import ThreadPool |
||||
from pathlib import Path |
||||
|
||||
import numpy as np |
||||
import torch |
||||
import torch.nn.functional as F |
||||
|
||||
from ultralytics.yolo.utils import LOGGER, NUM_THREADS, ops |
||||
from ultralytics.yolo.utils.checks import check_requirements |
||||
from ultralytics.yolo.utils.metrics import SegmentMetrics, box_iou, mask_iou |
||||
from ultralytics.yolo.utils.plotting import output_to_target, plot_images |
||||
from ultralytics.yolo.v8.detect import DetectionValidator |
||||
|
||||
|
||||
class FastSAMValidator(DetectionValidator): |
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): |
||||
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.""" |
||||
super().__init__(dataloader, save_dir, pbar, args, _callbacks) |
||||
self.args.task = 'segment' |
||||
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot) |
||||
|
||||
def preprocess(self, batch): |
||||
"""Preprocesses batch by converting masks to float and sending to device.""" |
||||
batch = super().preprocess(batch) |
||||
batch['masks'] = batch['masks'].to(self.device).float() |
||||
return batch |
||||
|
||||
def init_metrics(self, model): |
||||
"""Initialize metrics and select mask processing function based on save_json flag.""" |
||||
super().init_metrics(model) |
||||
self.plot_masks = [] |
||||
if self.args.save_json: |
||||
check_requirements('pycocotools>=2.0.6') |
||||
self.process = ops.process_mask_upsample # more accurate |
||||
else: |
||||
self.process = ops.process_mask # faster |
||||
|
||||
def get_desc(self): |
||||
"""Return a formatted description of evaluation metrics.""" |
||||
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P', |
||||
'R', 'mAP50', 'mAP50-95)') |
||||
|
||||
def postprocess(self, preds): |
||||
"""Postprocesses YOLO predictions and returns output detections with proto.""" |
||||
p = ops.non_max_suppression(preds[0], |
||||
self.args.conf, |
||||
self.args.iou, |
||||
labels=self.lb, |
||||
multi_label=True, |
||||
agnostic=self.args.single_cls, |
||||
max_det=self.args.max_det, |
||||
nc=self.nc) |
||||
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported |
||||
return p, proto |
||||
|
||||
def update_metrics(self, preds, batch): |
||||
"""Metrics.""" |
||||
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])): |
||||
idx = batch['batch_idx'] == si |
||||
cls = batch['cls'][idx] |
||||
bbox = batch['bboxes'][idx] |
||||
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions |
||||
shape = batch['ori_shape'][si] |
||||
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init |
||||
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init |
||||
self.seen += 1 |
||||
|
||||
if npr == 0: |
||||
if nl: |
||||
self.stats.append((correct_bboxes, correct_masks, *torch.zeros( |
||||
(2, 0), device=self.device), cls.squeeze(-1))) |
||||
if self.args.plots: |
||||
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1)) |
||||
continue |
||||
|
||||
# Masks |
||||
midx = [si] if self.args.overlap_mask else idx |
||||
gt_masks = batch['masks'][midx] |
||||
pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch['img'][si].shape[1:]) |
||||
|
||||
# Predictions |
||||
if self.args.single_cls: |
||||
pred[:, 5] = 0 |
||||
predn = pred.clone() |
||||
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape, |
||||
ratio_pad=batch['ratio_pad'][si]) # native-space pred |
||||
|
||||
# Evaluate |
||||
if nl: |
||||
height, width = batch['img'].shape[2:] |
||||
tbox = ops.xywh2xyxy(bbox) * torch.tensor( |
||||
(width, height, width, height), device=self.device) # target boxes |
||||
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape, |
||||
ratio_pad=batch['ratio_pad'][si]) # native-space labels |
||||
labelsn = torch.cat((cls, tbox), 1) # native-space labels |
||||
correct_bboxes = self._process_batch(predn, labelsn) |
||||
# TODO: maybe remove these `self.` arguments as they already are member variable |
||||
correct_masks = self._process_batch(predn, |
||||
labelsn, |
||||
pred_masks, |
||||
gt_masks, |
||||
overlap=self.args.overlap_mask, |
||||
masks=True) |
||||
if self.args.plots: |
||||
self.confusion_matrix.process_batch(predn, labelsn) |
||||
|
||||
# Append correct_masks, correct_boxes, pconf, pcls, tcls |
||||
self.stats.append((correct_bboxes, correct_masks, pred[:, 4], pred[:, 5], cls.squeeze(-1))) |
||||
|
||||
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8) |
||||
if self.args.plots and self.batch_i < 3: |
||||
self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot |
||||
|
||||
# Save |
||||
if self.args.save_json: |
||||
pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), |
||||
shape, |
||||
ratio_pad=batch['ratio_pad'][si]) |
||||
self.pred_to_json(predn, batch['im_file'][si], pred_masks) |
||||
# if self.args.save_txt: |
||||
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') |
||||
|
||||
def finalize_metrics(self, *args, **kwargs): |
||||
"""Sets speed and confusion matrix for evaluation metrics.""" |
||||
self.metrics.speed = self.speed |
||||
self.metrics.confusion_matrix = self.confusion_matrix |
||||
|
||||
def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False): |
||||
""" |
||||
Return correct prediction matrix |
||||
Arguments: |
||||
detections (array[N, 6]), x1, y1, x2, y2, conf, class |
||||
labels (array[M, 5]), class, x1, y1, x2, y2 |
||||
Returns: |
||||
correct (array[N, 10]), for 10 IoU levels |
||||
""" |
||||
if masks: |
||||
if overlap: |
||||
nl = len(labels) |
||||
index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1 |
||||
gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640) |
||||
gt_masks = torch.where(gt_masks == index, 1.0, 0.0) |
||||
if gt_masks.shape[1:] != pred_masks.shape[1:]: |
||||
gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0] |
||||
gt_masks = gt_masks.gt_(0.5) |
||||
iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1)) |
||||
else: # boxes |
||||
iou = box_iou(labels[:, 1:], detections[:, :4]) |
||||
|
||||
correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool) |
||||
correct_class = labels[:, 0:1] == detections[:, 5] |
||||
for i in range(len(self.iouv)): |
||||
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match |
||||
if x[0].shape[0]: |
||||
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), |
||||
1).cpu().numpy() # [label, detect, iou] |
||||
if x[0].shape[0] > 1: |
||||
matches = matches[matches[:, 2].argsort()[::-1]] |
||||
matches = matches[np.unique(matches[:, 1], return_index=True)[1]] |
||||
# matches = matches[matches[:, 2].argsort()[::-1]] |
||||
matches = matches[np.unique(matches[:, 0], return_index=True)[1]] |
||||
correct[matches[:, 1].astype(int), i] = True |
||||
return torch.tensor(correct, dtype=torch.bool, device=detections.device) |
||||
|
||||
def plot_val_samples(self, batch, ni): |
||||
"""Plots validation samples with bounding box labels.""" |
||||
plot_images(batch['img'], |
||||
batch['batch_idx'], |
||||
batch['cls'].squeeze(-1), |
||||
batch['bboxes'], |
||||
batch['masks'], |
||||
paths=batch['im_file'], |
||||
fname=self.save_dir / f'val_batch{ni}_labels.jpg', |
||||
names=self.names, |
||||
on_plot=self.on_plot) |
||||
|
||||
def plot_predictions(self, batch, preds, ni): |
||||
"""Plots batch predictions with masks and bounding boxes.""" |
||||
plot_images( |
||||
batch['img'], |
||||
*output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed |
||||
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks, |
||||
paths=batch['im_file'], |
||||
fname=self.save_dir / f'val_batch{ni}_pred.jpg', |
||||
names=self.names, |
||||
on_plot=self.on_plot) # pred |
||||
self.plot_masks.clear() |
||||
|
||||
def pred_to_json(self, predn, filename, pred_masks): |
||||
"""Save one JSON result.""" |
||||
# Example result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236} |
||||
from pycocotools.mask import encode # noqa |
||||
|
||||
def single_encode(x): |
||||
"""Encode predicted masks as RLE and append results to jdict.""" |
||||
rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0] |
||||
rle['counts'] = rle['counts'].decode('utf-8') |
||||
return rle |
||||
|
||||
stem = Path(filename).stem |
||||
image_id = int(stem) if stem.isnumeric() else stem |
||||
box = ops.xyxy2xywh(predn[:, :4]) # xywh |
||||
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner |
||||
pred_masks = np.transpose(pred_masks, (2, 0, 1)) |
||||
with ThreadPool(NUM_THREADS) as pool: |
||||
rles = pool.map(single_encode, pred_masks) |
||||
for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())): |
||||
self.jdict.append({ |
||||
'image_id': image_id, |
||||
'category_id': self.class_map[int(p[5])], |
||||
'bbox': [round(x, 3) for x in b], |
||||
'score': round(p[4], 5), |
||||
'segmentation': rles[i]}) |
||||
|
||||
def eval_json(self, stats): |
||||
"""Return COCO-style object detection evaluation metrics.""" |
||||
if self.args.save_json and self.is_coco and len(self.jdict): |
||||
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations |
||||
pred_json = self.save_dir / 'predictions.json' # predictions |
||||
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...') |
||||
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb |
||||
check_requirements('pycocotools>=2.0.6') |
||||
from pycocotools.coco import COCO # noqa |
||||
from pycocotools.cocoeval import COCOeval # noqa |
||||
|
||||
for x in anno_json, pred_json: |
||||
assert x.is_file(), f'{x} file not found' |
||||
anno = COCO(str(anno_json)) # init annotations api |
||||
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) |
||||
for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]): |
||||
if self.is_coco: |
||||
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval |
||||
eval.evaluate() |
||||
eval.accumulate() |
||||
eval.summarize() |
||||
idx = i * 4 + 2 |
||||
stats[self.metrics.keys[idx + 1]], stats[ |
||||
self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50 |
||||
except Exception as e: |
||||
LOGGER.warning(f'pycocotools unable to run: {e}') |
||||
return stats |
Loading…
Reference in new issue