|
|
|
@ -1,25 +1,39 @@ |
|
|
|
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license |
|
|
|
|
|
|
|
|
|
import argparse |
|
|
|
|
from typing import List, Tuple, Union |
|
|
|
|
|
|
|
|
|
import cv2 |
|
|
|
|
import numpy as np |
|
|
|
|
import onnxruntime as ort |
|
|
|
|
import torch |
|
|
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
|
import ultralytics.utils.ops as ops |
|
|
|
|
from ultralytics.engine.results import Results |
|
|
|
|
from ultralytics.utils import ASSETS, yaml_load |
|
|
|
|
from ultralytics.utils.checks import check_yaml |
|
|
|
|
from ultralytics.utils.plotting import Colors |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class YOLOv8Seg: |
|
|
|
|
"""YOLOv8 segmentation model.""" |
|
|
|
|
|
|
|
|
|
def __init__(self, onnx_model): |
|
|
|
|
def __init__(self, onnx_model, conf_threshold=0.4): |
|
|
|
|
""" |
|
|
|
|
Initialization. |
|
|
|
|
Initializes the object detection model using an ONNX model. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
onnx_model (str): Path to the ONNX model. |
|
|
|
|
onnx_model (str): Path to the ONNX model file. |
|
|
|
|
conf_threshold (float, optional): Confidence threshold for detections. Defaults to 0.4. |
|
|
|
|
|
|
|
|
|
Attributes: |
|
|
|
|
session (ort.InferenceSession): ONNX Runtime session for running inference. |
|
|
|
|
ndtype (numpy.dtype): Data type for model input (FP16 or FP32). |
|
|
|
|
model_height (int): Height of the model's input image. |
|
|
|
|
model_width (int): Width of the model's input image. |
|
|
|
|
classes (list): List of class names from the COCO dataset. |
|
|
|
|
device (str): Specifies whether inference runs on CPU or GPU. |
|
|
|
|
conf_threshold (float): Confidence threshold for filtering detections. |
|
|
|
|
""" |
|
|
|
|
# Build Ort session |
|
|
|
|
self.session = ort.InferenceSession( |
|
|
|
@ -38,281 +52,190 @@ class YOLOv8Seg: |
|
|
|
|
# Load COCO class names |
|
|
|
|
self.classes = yaml_load(check_yaml("coco8.yaml"))["names"] |
|
|
|
|
|
|
|
|
|
# Create color palette |
|
|
|
|
self.color_palette = Colors() |
|
|
|
|
# Device |
|
|
|
|
self.device = "cuda:0" if ort.get_device().lower() == "gpu" else "cpu" |
|
|
|
|
|
|
|
|
|
def __call__(self, im0, conf_threshold=0.4, iou_threshold=0.45, nm=32): |
|
|
|
|
# Confidence |
|
|
|
|
self.conf_threshold = conf_threshold |
|
|
|
|
|
|
|
|
|
def __call__(self, im0): |
|
|
|
|
""" |
|
|
|
|
The whole pipeline: pre-process -> inference -> post-process. |
|
|
|
|
Runs inference on the input image using the ONNX model. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
im0 (Numpy.ndarray): original input image. |
|
|
|
|
conf_threshold (float): confidence threshold for filtering predictions. |
|
|
|
|
iou_threshold (float): iou threshold for NMS. |
|
|
|
|
nm (int): the number of masks. |
|
|
|
|
im0 (numpy.ndarray): The original input image in BGR format. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
boxes (List): list of bounding boxes. |
|
|
|
|
segments (List): list of segments. |
|
|
|
|
masks (np.ndarray): [N, H, W], output masks. |
|
|
|
|
list: Processed detection results after post-processing. |
|
|
|
|
|
|
|
|
|
Example: |
|
|
|
|
>>> detector = Model("yolov8.onnx") |
|
|
|
|
>>> results = detector(image) # Runs inference and returns detections. |
|
|
|
|
""" |
|
|
|
|
# Pre-process |
|
|
|
|
im, ratio, (pad_w, pad_h) = self.preprocess(im0) |
|
|
|
|
processed_image = self.preprocess(im0) |
|
|
|
|
|
|
|
|
|
# Ort inference |
|
|
|
|
preds = self.session.run(None, {self.session.get_inputs()[0].name: im}) |
|
|
|
|
predictions = self.session.run(None, {self.session.get_inputs()[0].name: processed_image}) |
|
|
|
|
|
|
|
|
|
# Post-process |
|
|
|
|
boxes, segments, masks = self.postprocess( |
|
|
|
|
preds, |
|
|
|
|
im0=im0, |
|
|
|
|
ratio=ratio, |
|
|
|
|
pad_w=pad_w, |
|
|
|
|
pad_h=pad_h, |
|
|
|
|
conf_threshold=conf_threshold, |
|
|
|
|
iou_threshold=iou_threshold, |
|
|
|
|
nm=nm, |
|
|
|
|
) |
|
|
|
|
return boxes, segments, masks |
|
|
|
|
results = self.postprocess(im0, processed_image, predictions) |
|
|
|
|
|
|
|
|
|
return results |
|
|
|
|
|
|
|
|
|
def preprocess(self, img): |
|
|
|
|
def preprocess(self, image, new_shape: Union[Tuple, List] = (640, 640)): |
|
|
|
|
""" |
|
|
|
|
Pre-processes the input image. |
|
|
|
|
Preprocesses the input image before feeding it into the model. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
img (Numpy.ndarray): image about to be processed. |
|
|
|
|
image (np.ndarray): The input image in BGR format. |
|
|
|
|
new_shape (Tuple or List, optional): The target shape for resizing. Defaults to (640, 640). |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
img_process (Numpy.ndarray): image preprocessed for inference. |
|
|
|
|
ratio (tuple): width, height ratios in letterbox. |
|
|
|
|
pad_w (float): width padding in letterbox. |
|
|
|
|
pad_h (float): height padding in letterbox. |
|
|
|
|
np.ndarray: Preprocessed image ready for model inference. |
|
|
|
|
|
|
|
|
|
Example: |
|
|
|
|
>>> processed_img = model.preprocess(image) |
|
|
|
|
""" |
|
|
|
|
# Resize and pad input image using letterbox() (Borrowed from Ultralytics) |
|
|
|
|
shape = img.shape[:2] # original image shape |
|
|
|
|
new_shape = (self.model_height, self.model_width) |
|
|
|
|
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) |
|
|
|
|
ratio = r, r |
|
|
|
|
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) |
|
|
|
|
pad_w, pad_h = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2 # wh padding |
|
|
|
|
if shape[::-1] != new_unpad: # resize |
|
|
|
|
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) |
|
|
|
|
top, bottom = int(round(pad_h - 0.1)), int(round(pad_h + 0.1)) |
|
|
|
|
left, right = int(round(pad_w - 0.1)), int(round(pad_w + 0.1)) |
|
|
|
|
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)) |
|
|
|
|
|
|
|
|
|
# Transforms: HWC to CHW -> BGR to RGB -> div(255) -> contiguous -> add axis(optional) |
|
|
|
|
img = np.ascontiguousarray(np.einsum("HWC->CHW", img)[::-1], dtype=self.ndtype) / 255.0 |
|
|
|
|
img_process = img[None] if len(img.shape) == 3 else img |
|
|
|
|
return img_process, ratio, (pad_w, pad_h) |
|
|
|
|
|
|
|
|
|
def postprocess(self, preds, im0, ratio, pad_w, pad_h, conf_threshold, iou_threshold, nm=32): |
|
|
|
|
image, _, _ = self.__resize_and_pad_image(image=image, new_shape=new_shape) |
|
|
|
|
image = self.__reshape_image(image=image) |
|
|
|
|
processed_image = image[None] if len(image.shape) == 3 else image |
|
|
|
|
return processed_image |
|
|
|
|
|
|
|
|
|
def __reshape_image(self, image: np.ndarray) -> np.ndarray: |
|
|
|
|
""" |
|
|
|
|
Post-process the prediction. |
|
|
|
|
Reshapes the image by changing its layout and normalizing pixel values. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
preds (Numpy.ndarray): predictions come from ort.session.run(). |
|
|
|
|
im0 (Numpy.ndarray): [h, w, c] original input image. |
|
|
|
|
ratio (tuple): width, height ratios in letterbox. |
|
|
|
|
pad_w (float): width padding in letterbox. |
|
|
|
|
pad_h (float): height padding in letterbox. |
|
|
|
|
conf_threshold (float): conf threshold. |
|
|
|
|
iou_threshold (float): iou threshold. |
|
|
|
|
nm (int): the number of masks. |
|
|
|
|
image (np.ndarray): The image to be reshaped. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
boxes (List): list of bounding boxes. |
|
|
|
|
segments (List): list of segments. |
|
|
|
|
masks (np.ndarray): [N, H, W], output masks. |
|
|
|
|
np.ndarray: Reshaped and normalized image. |
|
|
|
|
|
|
|
|
|
Example: |
|
|
|
|
>>> reshaped_img = model.__reshape_image(image) |
|
|
|
|
""" |
|
|
|
|
x, protos = preds[0], preds[1] # Two outputs: predictions and protos |
|
|
|
|
image = image.transpose([2, 0, 1]) |
|
|
|
|
image = image[np.newaxis, ...] |
|
|
|
|
image = np.ascontiguousarray(image).astype(np.float32) / 255 |
|
|
|
|
return image |
|
|
|
|
|
|
|
|
|
def __resize_and_pad_image( |
|
|
|
|
self, image=np.ndarray, new_shape: Union[Tuple, List] = (640, 640), color: Union[Tuple, List] = (114, 114, 114) |
|
|
|
|
): |
|
|
|
|
""" |
|
|
|
|
Resizes and pads the input image while maintaining the aspect ratio. |
|
|
|
|
|
|
|
|
|
# Transpose dim 1: (Batch_size, xywh_conf_cls_nm, Num_anchors) -> (Batch_size, Num_anchors, xywh_conf_cls_nm) |
|
|
|
|
x = np.einsum("bcn->bnc", x) |
|
|
|
|
Args: |
|
|
|
|
image (np.ndarray): The input image. |
|
|
|
|
new_shape (Tuple or List, optional): Target shape (width, height). Defaults to (640, 640). |
|
|
|
|
color (Tuple or List, optional): Padding color. Defaults to (114, 114, 114). |
|
|
|
|
|
|
|
|
|
# Predictions filtering by conf-threshold |
|
|
|
|
x = x[np.amax(x[..., 4:-nm], axis=-1) > conf_threshold] |
|
|
|
|
Returns: |
|
|
|
|
Tuple[np.ndarray, float, float]: The resized image along with padding values. |
|
|
|
|
|
|
|
|
|
# Create a new matrix which merge these(box, score, cls, nm) into one |
|
|
|
|
# For more details about `numpy.c_()`: https://numpy.org/doc/1.26/reference/generated/numpy.c_.html |
|
|
|
|
x = np.c_[x[..., :4], np.amax(x[..., 4:-nm], axis=-1), np.argmax(x[..., 4:-nm], axis=-1), x[..., -nm:]] |
|
|
|
|
Example: |
|
|
|
|
>>> resized_img, dw, dh = model.__resize_and_pad_image(image) |
|
|
|
|
""" |
|
|
|
|
shape = image.shape[:2] # original image shape |
|
|
|
|
|
|
|
|
|
# NMS filtering |
|
|
|
|
x = x[cv2.dnn.NMSBoxes(x[:, :4], x[:, 4], conf_threshold, iou_threshold)] |
|
|
|
|
if isinstance(new_shape, int): |
|
|
|
|
new_shape = (new_shape, new_shape) |
|
|
|
|
|
|
|
|
|
# Decode and return |
|
|
|
|
if len(x) > 0: |
|
|
|
|
# Bounding boxes format change: cxcywh -> xyxy |
|
|
|
|
x[..., [0, 1]] -= x[..., [2, 3]] / 2 |
|
|
|
|
x[..., [2, 3]] += x[..., [0, 1]] |
|
|
|
|
# Scale ratio (new / old) |
|
|
|
|
ratio = min(new_shape[0] / shape[1], new_shape[1] / shape[0]) |
|
|
|
|
|
|
|
|
|
# Rescales bounding boxes from model shape(model_height, model_width) to the shape of original image |
|
|
|
|
x[..., :4] -= [pad_w, pad_h, pad_w, pad_h] |
|
|
|
|
x[..., :4] /= min(ratio) |
|
|
|
|
new_unpad = int(round(shape[1] * ratio)), int(round(shape[0] * ratio)) |
|
|
|
|
delta_width, delta_height = new_shape[0] - new_unpad[0], new_shape[1] - new_unpad[1] |
|
|
|
|
|
|
|
|
|
# Bounding boxes boundary clamp |
|
|
|
|
x[..., [0, 2]] = x[:, [0, 2]].clip(0, im0.shape[1]) |
|
|
|
|
x[..., [1, 3]] = x[:, [1, 3]].clip(0, im0.shape[0]) |
|
|
|
|
# Divide padding into 2 sides |
|
|
|
|
delta_width /= 2 |
|
|
|
|
delta_height /= 2 |
|
|
|
|
|
|
|
|
|
# Process masks |
|
|
|
|
masks = self.process_mask(protos[0], x[:, 6:], x[:, :4], im0.shape) |
|
|
|
|
image = cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR) if shape[::-1] == new_unpad else image |
|
|
|
|
|
|
|
|
|
# Masks -> Segments(contours) |
|
|
|
|
segments = self.masks2segments(masks) |
|
|
|
|
return x[..., :6], segments, masks # boxes, segments, masks |
|
|
|
|
else: |
|
|
|
|
return [], [], [] |
|
|
|
|
top, bottom = int(round(delta_height - 0.1)), int(round(delta_height + 0.1)) |
|
|
|
|
left, right = int(round(delta_width - 0.1)), int(round(delta_width + 0.1)) |
|
|
|
|
image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) |
|
|
|
|
return image, delta_width, delta_height |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def masks2segments(masks): |
|
|
|
|
def postprocess(self, image, processed_image, predictions): |
|
|
|
|
""" |
|
|
|
|
Takes a list of masks(n,h,w) and returns a list of segments(n,xy), from |
|
|
|
|
https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py. |
|
|
|
|
Post-processes model predictions to extract meaningful results. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
masks (numpy.ndarray): the output of the model, which is a tensor of shape (batch_size, 160, 160). |
|
|
|
|
image (np.ndarray): The original input image. |
|
|
|
|
processed_image (np.ndarray): The preprocessed image used for inference. |
|
|
|
|
predictions (list): Model output predictions. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
segments (List): list of segment masks. |
|
|
|
|
""" |
|
|
|
|
segments = [] |
|
|
|
|
for x in masks.astype("uint8"): |
|
|
|
|
c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[0] # CHAIN_APPROX_SIMPLE |
|
|
|
|
if c: |
|
|
|
|
c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2) |
|
|
|
|
else: |
|
|
|
|
c = np.zeros((0, 2)) # no segments found |
|
|
|
|
segments.append(c.astype("float32")) |
|
|
|
|
return segments |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def crop_mask(masks, boxes): |
|
|
|
|
""" |
|
|
|
|
Takes a mask and a bounding box, and returns a mask that is cropped to the bounding box, from |
|
|
|
|
https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py. |
|
|
|
|
list: Processed detection results. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
masks (Numpy.ndarray): [n, h, w] tensor of masks. |
|
|
|
|
boxes (Numpy.ndarray): [n, 4] tensor of bbox coordinates in relative point form. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
(Numpy.ndarray): The masks are being cropped to the bounding box. |
|
|
|
|
Example: |
|
|
|
|
>>> results = model.postprocess(image, processed_image, predictions) |
|
|
|
|
""" |
|
|
|
|
n, h, w = masks.shape |
|
|
|
|
x1, y1, x2, y2 = np.split(boxes[:, :, None], 4, 1) |
|
|
|
|
r = np.arange(w, dtype=x1.dtype)[None, None, :] |
|
|
|
|
c = np.arange(h, dtype=x1.dtype)[None, :, None] |
|
|
|
|
return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2)) |
|
|
|
|
torch_tensor_predictions = [torch.from_numpy(output) for output in predictions] |
|
|
|
|
torch_tensor_boxes_confidence_category_predictions = torch_tensor_predictions[0] |
|
|
|
|
masks_predictions_tensor = torch_tensor_predictions[1].to(self.device) |
|
|
|
|
|
|
|
|
|
nms_boxes_confidence_category_predictions_tensor = ops.non_max_suppression( |
|
|
|
|
torch_tensor_boxes_confidence_category_predictions, |
|
|
|
|
conf_thres=self.conf_threshold, |
|
|
|
|
nc=len(self.classes), |
|
|
|
|
agnostic=False, |
|
|
|
|
max_det=100, |
|
|
|
|
max_time_img=0.001, |
|
|
|
|
max_nms=1000, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
def process_mask(self, protos, masks_in, bboxes, im0_shape): |
|
|
|
|
results = [] |
|
|
|
|
for idx, predictions in enumerate(nms_boxes_confidence_category_predictions_tensor): |
|
|
|
|
predictions = predictions.to(self.device) |
|
|
|
|
masks = self.__process_mask( |
|
|
|
|
masks_predictions_tensor[idx], |
|
|
|
|
predictions[:, 6:], |
|
|
|
|
predictions[:, :4], |
|
|
|
|
processed_image.shape[2:], |
|
|
|
|
upsample=True, |
|
|
|
|
) # HWC |
|
|
|
|
predictions[:, :4] = ops.scale_boxes(processed_image.shape[2:], predictions[:, :4], image.shape) |
|
|
|
|
results.append(Results(image, path="", names=self.classes, boxes=predictions[:, :6], masks=masks)) |
|
|
|
|
|
|
|
|
|
return results |
|
|
|
|
|
|
|
|
|
def __process_mask(self, protos, masks_in, bboxes, shape, upsample=False): |
|
|
|
|
""" |
|
|
|
|
Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher |
|
|
|
|
quality but is slower, from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py. |
|
|
|
|
Processes segmentation masks from the model output. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
protos (numpy.ndarray): [mask_dim, mask_h, mask_w]. |
|
|
|
|
masks_in (numpy.ndarray): [n, mask_dim], n is number of masks after nms. |
|
|
|
|
bboxes (numpy.ndarray): bboxes re-scaled to original image shape. |
|
|
|
|
im0_shape (tuple): the size of the input image (h,w,c). |
|
|
|
|
protos (torch.Tensor): The prototype mask predictions from the model. |
|
|
|
|
masks_in (torch.Tensor): The raw mask predictions. |
|
|
|
|
bboxes (torch.Tensor): Bounding boxes for the detected objects. |
|
|
|
|
shape (Tuple): Target shape for mask resizing. |
|
|
|
|
upsample (bool, optional): Whether to upscale masks to match the original image size. Defaults to False. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
(numpy.ndarray): The upsampled masks. |
|
|
|
|
""" |
|
|
|
|
c, mh, mw = protos.shape |
|
|
|
|
masks = np.matmul(masks_in, protos.reshape((c, -1))).reshape((-1, mh, mw)).transpose(1, 2, 0) # HWN |
|
|
|
|
masks = np.ascontiguousarray(masks) |
|
|
|
|
masks = self.scale_mask(masks, im0_shape) # re-scale mask from P3 shape to original input image shape |
|
|
|
|
masks = np.einsum("HWN -> NHW", masks) # HWN -> NHW |
|
|
|
|
masks = self.crop_mask(masks, bboxes) |
|
|
|
|
return np.greater(masks, 0.5) |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def scale_mask(masks, im0_shape, ratio_pad=None): |
|
|
|
|
""" |
|
|
|
|
Takes a mask, and resizes it to the original image size, from |
|
|
|
|
https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py. |
|
|
|
|
torch.Tensor: Processed binary masks. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3]. |
|
|
|
|
im0_shape (tuple): the original image shape. |
|
|
|
|
ratio_pad (tuple): the ratio of the padding to the original image. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
masks (np.ndarray): The masks that are being returned. |
|
|
|
|
""" |
|
|
|
|
im1_shape = masks.shape[:2] |
|
|
|
|
if ratio_pad is None: # calculate from im0_shape |
|
|
|
|
gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new |
|
|
|
|
pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding |
|
|
|
|
else: |
|
|
|
|
pad = ratio_pad[1] |
|
|
|
|
|
|
|
|
|
# Calculate tlbr of mask |
|
|
|
|
top, left = int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1)) # y, x |
|
|
|
|
bottom, right = int(round(im1_shape[0] - pad[1] + 0.1)), int(round(im1_shape[1] - pad[0] + 0.1)) |
|
|
|
|
if len(masks.shape) < 2: |
|
|
|
|
raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}') |
|
|
|
|
masks = masks[top:bottom, left:right] |
|
|
|
|
masks = cv2.resize( |
|
|
|
|
masks, (im0_shape[1], im0_shape[0]), interpolation=cv2.INTER_LINEAR |
|
|
|
|
) # INTER_CUBIC would be better |
|
|
|
|
if len(masks.shape) == 2: |
|
|
|
|
masks = masks[:, :, None] |
|
|
|
|
return masks |
|
|
|
|
|
|
|
|
|
def draw_and_visualize(self, im, bboxes, segments, vis=False, save=True): |
|
|
|
|
Example: |
|
|
|
|
>>> masks = model.__process_mask(protos, masks_in, bboxes, shape, upsample=True) |
|
|
|
|
""" |
|
|
|
|
Draw and visualize results. |
|
|
|
|
c, mh, mw = protos.shape # CHW |
|
|
|
|
ih, iw = shape |
|
|
|
|
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) # CHW |
|
|
|
|
width_ratio = mw / iw |
|
|
|
|
height_ratio = mh / ih |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
im (np.ndarray): original image, shape [h, w, c]. |
|
|
|
|
bboxes (numpy.ndarray): [n, 4], n is number of bboxes. |
|
|
|
|
segments (List): list of segment masks. |
|
|
|
|
vis (bool): imshow using OpenCV. |
|
|
|
|
save (bool): save image annotated. |
|
|
|
|
downsampled_bboxes = bboxes.clone() |
|
|
|
|
downsampled_bboxes[:, 0] *= width_ratio |
|
|
|
|
downsampled_bboxes[:, 2] *= width_ratio |
|
|
|
|
downsampled_bboxes[:, 3] *= height_ratio |
|
|
|
|
downsampled_bboxes[:, 1] *= height_ratio |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
None |
|
|
|
|
""" |
|
|
|
|
# Draw rectangles and polygons |
|
|
|
|
im_canvas = im.copy() |
|
|
|
|
for (*box, conf, cls_), segment in zip(bboxes, segments): |
|
|
|
|
# draw contour and fill mask |
|
|
|
|
cv2.polylines(im, np.int32([segment]), True, (255, 255, 255), 2) # white borderline |
|
|
|
|
cv2.fillPoly(im_canvas, np.int32([segment]), self.color_palette(int(cls_), bgr=True)) |
|
|
|
|
|
|
|
|
|
# draw bbox rectangle |
|
|
|
|
cv2.rectangle( |
|
|
|
|
im, |
|
|
|
|
(int(box[0]), int(box[1])), |
|
|
|
|
(int(box[2]), int(box[3])), |
|
|
|
|
self.color_palette(int(cls_), bgr=True), |
|
|
|
|
1, |
|
|
|
|
cv2.LINE_AA, |
|
|
|
|
) |
|
|
|
|
cv2.putText( |
|
|
|
|
im, |
|
|
|
|
f"{self.classes[cls_]}: {conf:.3f}", |
|
|
|
|
(int(box[0]), int(box[1] - 9)), |
|
|
|
|
cv2.FONT_HERSHEY_SIMPLEX, |
|
|
|
|
0.7, |
|
|
|
|
self.color_palette(int(cls_), bgr=True), |
|
|
|
|
2, |
|
|
|
|
cv2.LINE_AA, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# Mix image |
|
|
|
|
im = cv2.addWeighted(im_canvas, 0.3, im, 0.7, 0) |
|
|
|
|
|
|
|
|
|
# Show image |
|
|
|
|
if vis: |
|
|
|
|
cv2.imshow("demo", im) |
|
|
|
|
cv2.waitKey(0) |
|
|
|
|
cv2.destroyAllWindows() |
|
|
|
|
|
|
|
|
|
# Save image |
|
|
|
|
if save: |
|
|
|
|
cv2.imwrite("demo.jpg", im) |
|
|
|
|
masks = ops.crop_mask(masks, downsampled_bboxes) # CHW |
|
|
|
|
if upsample: |
|
|
|
|
masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW |
|
|
|
|
return masks.gt_(0.5).to(self.device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
@ -321,18 +244,18 @@ if __name__ == "__main__": |
|
|
|
|
parser.add_argument("--model", type=str, required=True, help="Path to ONNX model") |
|
|
|
|
parser.add_argument("--source", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image") |
|
|
|
|
parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold") |
|
|
|
|
parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold") |
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
|
# Build model |
|
|
|
|
model = YOLOv8Seg(args.model) |
|
|
|
|
model = YOLOv8Seg(args.model, args.conf) |
|
|
|
|
|
|
|
|
|
# Read image by OpenCV |
|
|
|
|
img = cv2.imread(args.source) |
|
|
|
|
img = cv2.resize(img, (640, 640)) # Can be changed based on your models expected size |
|
|
|
|
|
|
|
|
|
# Inference |
|
|
|
|
boxes, segments, _ = model(img, conf_threshold=args.conf, iou_threshold=args.iou) |
|
|
|
|
results = model(img) |
|
|
|
|
|
|
|
|
|
# Draw bboxes and polygons |
|
|
|
|
if len(boxes) > 0: |
|
|
|
|
model.draw_and_visualize(img, boxes, segments, vis=False, save=True) |
|
|
|
|
cv2.imshow("Segmented Image", results[0].plot()) |
|
|
|
|
cv2.waitKey(0) |
|
|
|
|
cv2.destroyAllWindows() |
|
|
|
|