You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
322 lines
13 KiB
322 lines
13 KiB
1 year ago
|
import argparse
|
||
|
|
||
|
import cv2
|
||
|
import numpy as np
|
||
|
import onnxruntime as ort
|
||
|
|
||
|
from ultralytics.utils import ASSETS, yaml_load
|
||
|
from ultralytics.utils.checks import check_yaml
|
||
|
from ultralytics.utils.plotting import Colors
|
||
|
|
||
|
|
||
|
class YOLOv8Seg:
    """YOLOv8 segmentation model executed with ONNX Runtime."""

    def __init__(self, onnx_model):
        """
        Initialization.

        Args:
            onnx_model (str): Path to the ONNX model.
        """

        # Build Ort session (CUDA is requested only when onnxruntime reports a GPU device)
        self.session = ort.InferenceSession(onnx_model,
                                            providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
                                            if ort.get_device() == 'GPU' else ['CPUExecutionProvider'])

        # YOLOv8-seg only has one input
        model_input = self.session.get_inputs()[0]

        # Numpy dtype: support both FP32 and FP16 onnx model
        self.ndtype = np.half if model_input.type == 'tensor(float16)' else np.single

        # Get model height and width from the last two dims of the input shape
        self.model_height, self.model_width = model_input.shape[-2:]

        # Load COCO class names
        self.classes = yaml_load(check_yaml('coco128.yaml'))['names']

        # Create color palette
        self.color_palette = Colors()

    def __call__(self, im0, conf_threshold=0.4, iou_threshold=0.45, nm=32):
        """
        The whole pipeline: pre-process -> inference -> post-process.

        Args:
            im0 (Numpy.ndarray): original input image.
            conf_threshold (float): confidence threshold for filtering predictions.
            iou_threshold (float): iou threshold for NMS.
            nm (int): the number of masks.

        Returns:
            boxes (List): list of bounding boxes.
            segments (List): list of segments.
            masks (np.ndarray): [N, H, W], output masks.
        """

        # Pre-process
        im, ratio, (pad_w, pad_h) = self.preprocess(im0)

        # Ort inference
        preds = self.session.run(None, {self.session.get_inputs()[0].name: im})

        # Post-process
        boxes, segments, masks = self.postprocess(preds,
                                                  im0=im0,
                                                  ratio=ratio,
                                                  pad_w=pad_w,
                                                  pad_h=pad_h,
                                                  conf_threshold=conf_threshold,
                                                  iou_threshold=iou_threshold,
                                                  nm=nm)
        return boxes, segments, masks

    def preprocess(self, img):
        """
        Pre-processes the input image.

        Args:
            img (Numpy.ndarray): image about to be processed.

        Returns:
            img_process (Numpy.ndarray): image preprocessed for inference.
            ratio (tuple): width, height ratios in letterbox.
            pad_w (float): width padding in letterbox.
            pad_h (float): height padding in letterbox.
        """

        # Resize and pad input image using letterbox() (Borrowed from Ultralytics)
        shape = img.shape[:2]  # original image shape (h, w)
        new_shape = (self.model_height, self.model_width)
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        ratio = r, r
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))  # (w, h)
        pad_w, pad_h = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2  # wh padding
        if shape[::-1] != new_unpad:  # resize
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
        # The -0.1 / +0.1 offsets split an odd total padding into two integers that still sum to it
        top, bottom = int(round(pad_h - 0.1)), int(round(pad_h + 0.1))
        left, right = int(round(pad_w - 0.1)), int(round(pad_w + 0.1))
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))

        # Transforms: HWC to CHW -> BGR to RGB -> div(255) -> contiguous -> add axis(optional)
        img = np.ascontiguousarray(np.einsum('HWC->CHW', img)[::-1], dtype=self.ndtype) / 255.0
        img_process = img[None] if len(img.shape) == 3 else img
        return img_process, ratio, (pad_w, pad_h)

    def postprocess(self, preds, im0, ratio, pad_w, pad_h, conf_threshold, iou_threshold, nm=32):
        """
        Post-process the prediction.

        Args:
            preds (Numpy.ndarray): predictions come from ort.session.run().
            im0 (Numpy.ndarray): [h, w, c] original input image.
            ratio (tuple): width, height ratios in letterbox.
            pad_w (float): width padding in letterbox.
            pad_h (float): height padding in letterbox.
            conf_threshold (float): conf threshold.
            iou_threshold (float): iou threshold.
            nm (int): the number of masks.

        Returns:
            boxes (numpy.ndarray | List): [n, 6] rows of (x1, y1, x2, y2, conf, cls); empty list if nothing is kept.
            segments (List): list of segments; empty list if nothing is kept.
            masks (np.ndarray | List): [N, H, W], output masks; empty list if nothing is kept.
        """
        x, protos = preds[0], preds[1]  # Two outputs: predictions and protos

        # Transpose the first output: (Batch_size, xywh_conf_cls_nm, Num_anchors) -> (Batch_size, Num_anchors, xywh_conf_cls_nm)
        x = np.einsum('bcn->bnc', x)

        # Predictions filtering by conf-threshold
        x = x[np.amax(x[..., 4:-nm], axis=-1) > conf_threshold]

        # Nothing survived the confidence filter: skip NMS and mask decoding entirely
        if len(x) == 0:
            return [], [], []

        # Create a new matrix which merge these(box, score, cls, nm) into one
        # For more details about `numpy.c_()`: https://numpy.org/doc/1.26/reference/generated/numpy.c_.html
        x = np.c_[x[..., :4], np.amax(x[..., 4:-nm], axis=-1), np.argmax(x[..., 4:-nm], axis=-1), x[..., -nm:]]

        # NMS filtering. cv2.dnn.NMSBoxes expects (top-left x, top-left y, w, h) rectangles, while the
        # predictions are (cx, cy, w, h), so shift the centers to the top-left corner on a copy first.
        nms_boxes = x[:, :4].copy()
        nms_boxes[:, :2] -= nms_boxes[:, 2:] / 2
        keep = cv2.dnn.NMSBoxes(nms_boxes, x[:, 4], conf_threshold, iou_threshold)
        # Flatten the indices: depending on the OpenCV version they can be (k,), (k, 1) or an empty tuple
        x = x[np.asarray(keep, dtype=int).reshape(-1)]

        # Decode and return
        if len(x) == 0:
            return [], [], []

        # Bounding boxes format change: cxcywh -> xyxy
        x[..., [0, 1]] -= x[..., [2, 3]] / 2
        x[..., [2, 3]] += x[..., [0, 1]]

        # Rescales bounding boxes from model shape(model_height, model_width) to the shape of original image
        x[..., :4] -= [pad_w, pad_h, pad_w, pad_h]
        x[..., :4] /= min(ratio)

        # Bounding boxes boundary clamp
        x[..., [0, 2]] = x[:, [0, 2]].clip(0, im0.shape[1])
        x[..., [1, 3]] = x[:, [1, 3]].clip(0, im0.shape[0])

        # Process masks
        masks = self.process_mask(protos[0], x[:, 6:], x[:, :4], im0.shape)

        # Masks -> Segments(contours)
        segments = self.masks2segments(masks)
        return x[..., :6], segments, masks  # boxes, segments, masks

    @staticmethod
    def masks2segments(masks):
        """
        It takes a list of masks(n,h,w) and returns a list of segments(n,xy) (Borrowed from
        https://github.com/ultralytics/ultralytics/blob/465df3024f44fa97d4fad9986530d5a13cdabdca/ultralytics/utils/ops.py#L750)

        Args:
            masks (numpy.ndarray): [n, h, w] masks; each one is cast to uint8 before contour extraction.

        Returns:
            segments (List): one float32 array of shape [k, 2] per mask holding the points of its longest
                external contour, or an empty [0, 2] array when no contour is found.
        """
        segments = []
        for mask in masks.astype('uint8'):
            contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[0]  # CHAIN_APPROX_SIMPLE
            if len(contours) > 0:
                # Keep only the contour with the most points
                longest = np.array([len(contour) for contour in contours]).argmax()
                points = np.array(contours[longest]).reshape(-1, 2)
            else:
                points = np.zeros((0, 2))  # no segments found
            segments.append(points.astype('float32'))
        return segments

    @staticmethod
    def crop_mask(masks, boxes):
        """
        It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box. (Borrowed from
        https://github.com/ultralytics/ultralytics/blob/465df3024f44fa97d4fad9986530d5a13cdabdca/ultralytics/utils/ops.py#L599)

        Args:
            masks (Numpy.ndarray): [n, h, w] tensor of masks.
            boxes (Numpy.ndarray): [n, 4] bbox coordinates (x1, y1, x2, y2) expressed in the pixel grid of `masks`.

        Returns:
            (Numpy.ndarray): [n, h, w] masks with every value outside the matching box set to zero.
        """
        n, h, w = masks.shape
        x1, y1, x2, y2 = np.split(boxes[:, :, None], 4, 1)  # each of shape [n, 1, 1]
        r = np.arange(w, dtype=x1.dtype)[None, None, :]  # column indices
        c = np.arange(h, dtype=x1.dtype)[None, :, None]  # row indices
        return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))

    def process_mask(self, protos, masks_in, bboxes, im0_shape):
        """
        Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher quality
        but is slower. (Borrowed from https://github.com/ultralytics/ultralytics/blob/465df3024f44fa97d4fad9986530d5a13cdabdca/ultralytics/utils/ops.py#L618)

        Args:
            protos (numpy.ndarray): [mask_dim, mask_h, mask_w].
            masks_in (numpy.ndarray): [n, mask_dim], n is number of masks after nms.
            bboxes (numpy.ndarray): bboxes re-scaled to original image shape.
            im0_shape (tuple): the size of the input image (h,w,c).

        Returns:
            (numpy.ndarray): boolean masks of shape [n, h, w] at the original image size.
        """
        c, mh, mw = protos.shape
        masks = np.matmul(masks_in, protos.reshape((c, -1))).reshape((-1, mh, mw)).transpose(1, 2, 0)  # HWN
        masks = np.ascontiguousarray(masks)
        masks = self.scale_mask(masks, im0_shape)  # re-scale mask from P3 shape to original input image shape
        masks = np.einsum('HWN -> NHW', masks)  # HWN -> NHW
        masks = self.crop_mask(masks, bboxes)
        # NOTE(review): the threshold is applied to the raw matmul output (no sigmoid) - confirm this is intended
        return np.greater(masks, 0.5)

    @staticmethod
    def scale_mask(masks, im0_shape, ratio_pad=None):
        """
        Takes a mask, and resizes it to the original image size. (Borrowed from
        https://github.com/ultralytics/ultralytics/blob/465df3024f44fa97d4fad9986530d5a13cdabdca/ultralytics/utils/ops.py#L305)

        Args:
            masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
            im0_shape (tuple): the original image shape.
            ratio_pad (tuple): optional; when given, its second item is used as the (w, h) padding.

        Returns:
            masks (np.ndarray): masks resized to (im0_shape[0], im0_shape[1]), always with a trailing channel axis.

        Raises:
            ValueError: if `masks` has fewer than 2 dimensions.
        """
        # Validate first: every computation below indexes both spatial dimensions
        if len(masks.shape) < 2:
            raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')

        im1_shape = masks.shape[:2]
        if ratio_pad is None:  # calculate from im0_shape
            gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])  # gain  = old / new
            pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2  # wh padding
        else:
            pad = ratio_pad[1]

        # Calculate tlbr of mask
        top, left = int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1))  # y, x
        bottom, right = int(round(im1_shape[0] - pad[1] + 0.1)), int(round(im1_shape[1] - pad[0] + 0.1))
        masks = masks[top:bottom, left:right]  # drop the letterbox padding
        masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]),
                           interpolation=cv2.INTER_LINEAR)  # INTER_CUBIC would be better
        if len(masks.shape) == 2:
            # Restore the channel axis when the resized result comes back 2-D
            masks = masks[:, :, None]
        return masks

    def draw_and_visualize(self, im, bboxes, segments, vis=False, save=True):
        """
        Draw and visualize results.

        Args:
            im (np.ndarray): original image, shape [h, w, c]. Contours, boxes and labels are drawn on it in place.
            bboxes (numpy.ndarray): [n, 6] rows of (x1, y1, x2, y2, conf, cls), n is number of bboxes.
            segments (List): list of segment masks.
            vis (bool): imshow using OpenCV.
            save (bool): save image annotated.

        Returns:
            None
        """

        # Draw rectangles and polygons
        im_canvas = im.copy()
        for (*box, conf, cls_), segment in zip(bboxes, segments):
            # The class id arrives as a float column of the box array; use an int for lookups
            class_id = int(cls_)
            color = self.color_palette(class_id, bgr=True)

            # draw contour and fill mask (skipped when no contour was found for this detection)
            if len(segment) > 0:
                cv2.polylines(im, np.int32([segment]), True, (255, 255, 255), 2)  # white borderline
                cv2.fillPoly(im_canvas, np.int32([segment]), color)

            # draw bbox rectangle
            cv2.rectangle(im, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])),
                          color, 1, cv2.LINE_AA)
            cv2.putText(im, f'{self.classes[class_id]}: {conf:.3f}', (int(box[0]), int(box[1] - 9)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2, cv2.LINE_AA)

        # Mix image
        im = cv2.addWeighted(im_canvas, 0.3, im, 0.7, 0)

        # Show image
        if vis:
            cv2.imshow('demo', im)
            cv2.waitKey(0)
            cv2.destroyAllWindows()

        # Save image
        if save:
            cv2.imwrite('demo.jpg', im)
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Create an argument parser to handle command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True, help='Path to ONNX model')
    parser.add_argument('--source', type=str, default=str(ASSETS / 'bus.jpg'), help='Path to input image')
    parser.add_argument('--conf', type=float, default=0.25, help='Confidence threshold')
    parser.add_argument('--iou', type=float, default=0.45, help='NMS IoU threshold')
    args = parser.parse_args()

    # Build model
    model = YOLOv8Seg(args.model)

    # Read image by OpenCV
    img = cv2.imread(args.source)
    # cv2.imread does not raise on failure, it returns None; fail here with a clear message
    if img is None:
        raise FileNotFoundError(f'Unable to read image from: {args.source}')

    # Inference
    boxes, segments, _ = model(img, conf_threshold=args.conf, iou_threshold=args.iou)

    # Draw bboxes and polygons (nothing is drawn or saved when there are no detections)
    if len(boxes) > 0:
        model.draw_and_visualize(img, boxes, segments, vis=False, save=True)
|