Support converting YOLOv8 segmentation models to ONNX and TensorRT

pull/6/head
triple-Mu 2 years ago
parent 15ca9b2cbc
commit a754dc3492
  1. 77
      export_seg.py
  2. 128
      infer.py
  3. 41
      models/common.py
  4. 9
      models/engine.py

@ -0,0 +1,77 @@
import argparse
from io import BytesIO
import onnx
import torch
from ultralytics import YOLO
from models.common import optim
try:
import onnxsim
except ImportError:
onnxsim = None
def parse_args():
    """Parse command-line arguments for the YOLOv8-seg ONNX export script.

    Returns:
        argparse.Namespace with fields: weights (str, required), opset (int),
        sim (bool), input_shape (list[int], length 4: N C H W), device (str).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-w',
                        '--weights',
                        type=str,
                        required=True,
                        help='PyTorch yolov8 weights')
    parser.add_argument('--opset',
                        type=int,
                        default=11,
                        help='ONNX opset version')
    parser.add_argument('--sim',
                        action='store_true',
                        help='simplify onnx model')
    parser.add_argument('--input-shape',
                        nargs='+',
                        type=int,
                        default=[1, 3, 640, 640],
                        help='Model input shape only for api builder')
    parser.add_argument('--device',
                        type=str,
                        default='cpu',
                        help='Export ONNX device')
    args = parser.parse_args()
    # Validate via argparse instead of a bare `assert`: asserts are stripped
    # under `python -O`, and parser.error() gives a proper usage message.
    if len(args.input_shape) != 4:
        parser.error('--input-shape expects exactly 4 values: N C H W')
    return args
def main(args):
    """Export a YOLOv8-seg checkpoint to ONNX with split output heads."""
    yolo = YOLO(args.weights)
    export_model = yolo.model.fuse().eval()
    # Swap Detect/Segment/C2f modules for their export-friendly variants.
    for module in export_model.modules():
        optim(module)
        module.to(args.device)
    export_model.to(args.device)

    dummy = torch.randn(args.input_shape).to(args.device)
    # Two warm-up passes so shape-dependent attributes (anchors, strides)
    # are cached before tracing.
    for _ in range(2):
        export_model(dummy)

    save_path = args.weights.replace('.pt', '.onnx')
    with BytesIO() as f:
        torch.onnx.export(
            export_model,
            dummy,
            f,
            opset_version=args.opset,
            input_names=['images'],
            output_names=['bboxes', 'scores', 'labels', 'maskconf', 'proto'])
        f.seek(0)
        onnx_model = onnx.load(f)
    onnx.checker.check_model(onnx_model)

    if args.sim:
        # Simplification is best-effort; fall back to the unsimplified graph.
        try:
            onnx_model, check = onnxsim.simplify(onnx_model)
            assert check, 'assert check failed'
        except Exception as e:
            print(f'Simplifier failure: {e}')
    onnx.save(onnx_model, save_path)
    print(f'ONNX export success, saved as {save_path}')


if __name__ == '__main__':
    main(parse_args())

@ -3,10 +3,15 @@ import argparse
import os
import random
from pathlib import Path
from typing import Any, List, Tuple, Union
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from numpy import ndarray
from torch import Tensor
from torchvision.ops import batched_nms
os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
@ -34,8 +39,24 @@ COLORS = {
for i, cls in enumerate(CLASSES)
}
# the same as yolov8
# 20-entry RGB palette for instance-mask overlays, normalized to [0, 1].
# (Removed a stale duplicate `def letterbox(...)` line left over from the
# pre-refactor version of this file.)
MASK_COLORS = np.array([(255, 56, 56), (255, 157, 151), (255, 112, 31),
                        (255, 178, 29), (207, 210, 49), (72, 249, 10),
                        (146, 204, 23), (61, 219, 134), (26, 147, 52),
                        (0, 212, 187), (44, 153, 168), (0, 194, 255),
                        (52, 69, 147), (100, 115, 255), (0, 24, 236),
                        (132, 56, 255), (82, 0, 133), (203, 56, 255),
                        (255, 149, 200), (255, 55, 199)],
                       dtype=np.float32) / 255.

# Opacity used when blending mask colors over the image.
ALPHA = 0.5
def letterbox(
im: ndarray,
new_shape: Union[Tuple, List] = (640, 640),
color: Union[Tuple, List] = (114, 114, 114)
) -> Tuple[ndarray, float, Tuple[float, float]]:
# Resize and pad image while meeting stride-multiple constraints
shape = im.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
@ -63,21 +84,27 @@ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
right,
cv2.BORDER_CONSTANT,
value=color) # add border
return im, np.array([r, r, r, r],
dtype=np.float32), np.array([dw, dh, dw, dh],
dtype=np.float32)
return im, r, (dw, dh)
def blob(im: ndarray) -> Tuple[ndarray, ndarray]:
    """Convert an HWC uint8 image into an NCHW float32 network blob.

    The diff had interleaved the old single-return version with the new
    two-output one; this is the consolidated two-output implementation.

    Args:
        im: (H, W, C) image array (uint8 expected).

    Returns:
        im:  (1, C, H, W) contiguous float32 blob scaled to [0, 1].
        seg: (H, W, C) float32 copy scaled to [0, 1], kept in HWC order
             for later segmentation-mask overlay drawing.
    """
    seg = im.astype(np.float32) / 255
    im = im.transpose([2, 0, 1])
    im = im[np.newaxis, ...]
    im = np.ascontiguousarray(im).astype(np.float32) / 255
    return im, seg
def main(args):
device = torch.device(args.device)
Engine = TRTModule(args.engine, device)
H, W = Engine.inp_info[0].shape[-2:]
# set desired output names order
if args.seg:
Engine.set_desired(['bboxes', 'scores', 'labels', 'maskconf', 'proto'])
else:
Engine.set_desired(['num_dets', 'bboxes', 'scores', 'labels'])
images_path = Path(args.imgs)
assert images_path.exists()
@ -98,18 +125,31 @@ def main(args):
save_image = save_path / image.name
bgr = cv2.imread(str(image))
draw = bgr.copy()
bgr, ratio, dwdh = letterbox(bgr)
bgr, ratio, dwdh = letterbox(bgr, (W, H))
dw, dh = int(dwdh[0]), int(dwdh[1])
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
tensor = blob(rgb)
ratio = torch.asarray(ratio, dtype=torch.float32, device=device)
dwdh = torch.asarray(dwdh, dtype=torch.float32, device=device)
tensor, seg_img = blob(rgb)
dwdh = torch.asarray(dwdh * 2, dtype=torch.float32, device=device)
tensor = torch.asarray(tensor, device=device)
num_dets, bboxes, scores, labels = Engine(tensor)
bboxes = bboxes[0, :num_dets.item()]
scores = scores[0, :num_dets.item()]
labels = labels[0, :num_dets.item()]
data = Engine(tensor)
if args.seg:
seg_img = torch.asarray(seg_img[dh:H - dh, dw:W - dw, :],
device=device)
bboxes, scores, labels, masks = seg_postprocess(
data, bgr.shape[:2], args.conf_thres, args.iou_thres)
mask, mask_color = [m[:, dh:H - dh, dw:W - dw, :] for m in masks]
inv_alph_masks = (1 - mask * 0.5).cumprod(0)
mcs = (mask_color * inv_alph_masks).sum(0) * 2
seg_img = (seg_img * inv_alph_masks[-1] + mcs) * 255
draw = cv2.resize(seg_img.cpu().numpy().astype(np.uint8),
draw.shape[:2][::-1])
else:
bboxes, scores, labels, masks = det_postprocess(data)
bboxes -= dwdh
bboxes /= ratio
for (bbox, score, label) in zip(bboxes, scores, labels):
bbox = bbox.round().int().tolist()
cls_id = int(label)
@ -128,6 +168,55 @@ def main(args):
cv2.imwrite(str(save_image), draw)
def crop_mask(masks: Tensor, bboxes: Tensor) -> Tensor:
    """Zero out mask pixels falling outside each mask's bounding box.

    Args:
        masks:  (n, h, w) per-detection mask scores.
        bboxes: (n, 4) boxes as (x1, y1, x2, y2) in mask-pixel coordinates.

    Returns:
        (n, h, w) masks with out-of-box pixels set to 0.
    """
    n, h, w = masks.shape
    # Each coordinate becomes shape (n, 1, 1) so it broadcasts over (h, w).
    x1, y1, x2, y2 = torch.chunk(bboxes[:, :, None], 4, 1)
    # Column-index grid, shape (1, 1, w).
    cols = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]
    # Row-index grid, shape (1, h, 1).
    rows = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]
    inside = (cols >= x1) * (cols < x2) * (rows >= y1) * (rows < y2)
    return masks * inside
def seg_postprocess(
        data: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor],
        shape: Union[Tuple, List],
        conf_thres: float = 0.25,
        iou_thres: float = 0.65) -> Tuple[Tensor, Tensor, Tensor, List]:
    # Decode raw segmentation-engine outputs into final detections + masks.
    #
    # data:  (bboxes, scores, labels, maskconf, proto) batch-1 engine outputs,
    #        in the order requested via Engine.set_desired in main().
    # shape: (H, W) of the letterboxed network input image.
    # Returns (bboxes, scores, labels, [binary masks, colored masks]);
    # the two mask tensors are shaped (n, H, W, 1) and (n, H, W, 3).
    assert len(data) == 5
    # Prototype masks are at 1/4 of the network input resolution —
    # TODO(review): confirm this matches the exported proto head stride.
    h, w = shape[0] // 4, shape[1] // 4  # 4x downsampling
    # Drop the batch dimension from every output.
    bboxes, scores, labels, maskconf, proto = (i[0] for i in data)
    # Confidence filtering.
    select = scores > conf_thres
    bboxes, scores, labels, maskconf = bboxes[select], scores[select], labels[
        select], maskconf[select]
    # Class-aware NMS (torchvision.ops.batched_nms).
    idx = batched_nms(bboxes, scores, labels, iou_thres)
    bboxes, scores, labels, maskconf = bboxes[idx], scores[idx], labels[
        idx], maskconf[idx]
    # Combine per-detection mask coefficients with the prototype masks,
    # then zero everything outside each box (boxes scaled to proto coords).
    masks = (maskconf @ proto).view(-1, h, w)
    masks = crop_mask(masks, bboxes / 4.)
    # Upsample masks back to the network input resolution.
    masks = F.interpolate(masks[None],
                          shape,
                          mode='bilinear',
                          align_corners=False)[0]
    # In-place threshold keeps float dtype (0.0/1.0 values); the trailing
    # axis makes masks broadcastable/matmul-compatible with the colors.
    masks = masks.gt_(0.5)[..., None]
    # Per-detection color from the 20-entry palette, premultiplied by ALPHA.
    cidx = (labels % len(MASK_COLORS)).cpu().numpy()
    mask_color = torch.tensor(MASK_COLORS[cidx].reshape(-1, 1, 1,
                                                        3)).to(bboxes) * ALPHA
    # (n, H, W, 1) @ (n, 1, 1, 3) -> (n, H, W, 3) colored overlay per mask.
    out = [masks, masks @ mask_color]
    return bboxes, scores, labels, out
def det_postprocess(data: Tuple[Tensor, Tensor, Tensor, Any], **kwargs):
    """Unpack detection-engine outputs and keep only the valid rows.

    `data` holds batch-1 tensors (num_dets, bboxes, scores, labels); only
    the first num_dets entries of each are real detections. Returns
    (bboxes, scores, labels, None) — the trailing None keeps the return
    shape aligned with seg_postprocess. Extra kwargs are accepted and
    ignored so both postprocess functions can be called uniformly.
    """
    assert len(data) == 4
    # Strip the batch dimension from each engine output.
    num_dets, bboxes, scores, labels = (item[0] for item in data)
    keep = num_dets.item()
    return bboxes[:keep], scores[:keep], labels[:keep], None
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--engine', type=str, help='Engine file')
@ -135,10 +224,19 @@ def parse_args():
parser.add_argument('--show',
action='store_true',
help='Show the detection results')
parser.add_argument('--seg', action='store_true', help='Seg inference')
parser.add_argument('--out-dir',
type=str,
default='./output',
help='Path to output file')
parser.add_argument('--conf-thres',
type=float,
default=0.25,
help='Confidence threshold')
parser.add_argument('--iou-thres',
type=float,
default=0.65,
help='Confidence threshold')
parser.add_argument('--device',
type=str,
default='cuda:0',

@ -125,9 +125,50 @@ class PostDetect(nn.Module):
self.iou_thres, self.conf_thres, self.topk)
class PostSeg(nn.Module):
    # Export-time replacement head for the ultralytics `Segment` module:
    # optim() swaps a Segment instance's __class__ to this class, so it
    # reuses the original instance's submodules and attributes (proto, cv2,
    # cv3, cv4, nl, nm, no, reg_max, stride) — nothing here defines them.
    export = True
    shape = None  # cached input feature shape; anchors rebuilt when it changes
    dynamic = False  # when True, anchors are recomputed on every forward

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x):
        # x: list of nl multi-scale feature maps from the neck.
        p = self.proto(x[0])  # mask protos
        bs = p.shape[0]  # batch size
        mc = torch.cat(
            [self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)],
            2)  # mask coefficients
        box, score, cls = self.forward_det(x)
        # Output order matches the ONNX export names:
        # bboxes, scores, labels, maskconf, proto.
        return box, score, cls, mc.transpose(1, 2), p.flatten(2)

    def forward_det(self, x):
        # Decode DFL box distributions and class scores into xyxy boxes,
        # per-anchor best score and class id (mirrors the PostDetect head).
        shape = x[0].shape
        b, res = shape[0], []
        for i in range(self.nl):
            res.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
        # Rebuild anchors/strides only when the input shape changes, so a
        # static export caches them as constants.
        if self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(
                0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape
        x = [i.view(b, self.no, -1) for i in res]
        y = torch.cat(x, 2)
        # First reg_max*4 channels are the box distribution, rest are class
        # logits (sigmoid -> probabilities).
        box, cls = y[:, :self.reg_max * 4, ...], y[:, self.reg_max * 4:,
                                                   ...].sigmoid()
        # DFL decode: softmax over the reg_max bins, expectation gives the
        # offset distance for each of the 4 box sides.
        box = box.view(b, 4, self.reg_max, -1).permute(0, 1, 3, 2).contiguous()
        box = box.softmax(-1) @ torch.arange(self.reg_max).to(box)
        # Convert (left, top, right, bottom) distances from the anchor point
        # into absolute x1y1x2y2 corners, then scale by the feature stride.
        box0, box1 = -box[:, :2, ...], box[:, 2:, ...]
        box = self.anchors.repeat(b, 2, 1) + torch.cat([box0, box1], 1)
        box = box * self.strides
        score, cls = cls.transpose(1, 2).max(dim=-1)
        return box.transpose(1, 2), score, cls
def optim(module: nn.Module):
    """Swap ultralytics head/block classes for export-friendly variants.

    Mutates `module` in place by reassigning its __class__; parameters and
    submodules are untouched, so Detect/Segment/C2f instances keep their
    learned weights while gaining the replacement forward(). Modules with
    any other class name are left unchanged.
    """
    # type(module).__name__ replaces the fragile
    # str(type(module))[6:-2].split('.')[-1] string surgery.
    name = type(module).__name__
    if name == 'Detect':
        module.__class__ = PostDetect
    elif name == 'Segment':
        module.__class__ = PostSeg
    elif name == 'C2f':
        module.__class__ = C2f

@ -237,6 +237,7 @@ class TRTModule(torch.nn.Module):
self.context = context
self.input_names = names[:num_inputs]
self.output_names = names[num_inputs:]
self.idx = list(range(self.num_outputs))
def __init_bindings(self) -> None:
dynamic = False
@ -270,6 +271,11 @@ class TRTModule(torch.nn.Module):
self.context.profiler = profiler \
if profiler is not None else trt.Profiler()
def set_desired(self, desired: Optional[Union[List, Tuple]]):
    """Choose the order in which forward() returns the engine outputs.

    For each requested output name, records its position within
    self.output_names in self.idx. Best-effort: if `desired` is None, not
    a list/tuple, or its length differs from num_outputs, the current
    ordering is kept unchanged.
    """
    if not isinstance(desired, (list, tuple)):
        return
    if len(desired) != self.num_outputs:
        return
    self.idx = [self.output_names.index(name) for name in desired]
def forward(self, *inputs) -> Union[Tuple, torch.Tensor]:
assert len(inputs) == self.num_inputs
@ -300,7 +306,8 @@ class TRTModule(torch.nn.Module):
self.context.execute_async_v2(self.bindings, self.stream.cuda_stream)
self.stream.synchronize()
return tuple(outputs) if len(outputs) > 1 else outputs[0]
return tuple(outputs[i]
for i in self.idx) if len(outputs) > 1 else outputs[0]
class TRTProfilerV1(trt.IProfiler):

Loading…
Cancel
Save