From bf458ef709d00528b011b5daae3e4885c59e398b Mon Sep 17 00:00:00 2001 From: Sagyndyk Date: Thu, 8 Aug 2024 17:15:31 +0400 Subject: [PATCH] Update infer-det.py added flag dynamic, for using dynamic onnx --- infer-det.py | 176 +++++++++++++++++++++++++++++---------------------- 1 file changed, 101 insertions(+), 75 deletions(-) diff --git a/infer-det.py b/infer-det.py index cb54def..d505847 100644 --- a/infer-det.py +++ b/infer-det.py @@ -1,91 +1,117 @@ -from models import TRTModule # isort:skip import argparse -from pathlib import Path +from io import BytesIO -import cv2 +import onnx import torch +from ultralytics import YOLO -from config import CLASSES_DET, COLORS -from models.torch_utils import det_postprocess -from models.utils import blob, letterbox, path_to_list +from models.common import PostDetect, optim +try: + import onnxsim +except ImportError: + onnxsim = None -def main(args: argparse.Namespace) -> None: - device = torch.device(args.device) - Engine = TRTModule(args.engine, device) - H, W = Engine.inp_info[0].shape[-2:] - # set desired output names order - Engine.set_desired(['num_dets', 'bboxes', 'scores', 'labels']) - - images = path_to_list(args.imgs) - save_path = Path(args.out_dir) - - if not args.show and not save_path.exists(): - save_path.mkdir(parents=True, exist_ok=True) - - for image in images: - save_image = save_path / image.name - bgr = cv2.imread(str(image)) - draw = bgr.copy() - bgr, ratio, dwdh = letterbox(bgr, (W, H)) - rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) - tensor = blob(rgb, return_seg=False) - dwdh = torch.asarray(dwdh * 2, dtype=torch.float32, device=device) - tensor = torch.asarray(tensor, device=device) - # inference - data = Engine(tensor) - - bboxes, scores, labels = det_postprocess(data) - if bboxes.numel() == 0: - # if no bounding box - print(f'{image}: no object!') - continue - bboxes -= dwdh - bboxes /= ratio - - for (bbox, score, label) in zip(bboxes, scores, labels): - bbox = bbox.round().int().tolist() - cls_id = int(label) - cls = CLASSES_DET[cls_id] - color = COLORS[cls] - - text = f'{cls}:{score:.3f}' - x1, y1, x2, y2 = bbox - - (_w, _h), _bl = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 1) - _y1 = min(y1 + 1, draw.shape[0]) - - cv2.rectangle(draw, (x1, y1), (x2, y2), color, 2) - cv2.rectangle(draw, (x1, _y1), (x1 + _w, _y1 + _h + _bl), (0, 0, 255), -1) - cv2.putText(draw, text, (x1, _y1 + _h), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (255, 255, 255), 2) - - if args.show: - cv2.imshow('result', draw) - cv2.waitKey(0) - else: - cv2.imwrite(str(save_image), draw) - - -def parse_args() -> argparse.Namespace: +def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--engine', type=str, help='Engine file') - parser.add_argument('--imgs', type=str, help='Images file') - parser.add_argument('--show', - action='store_true', - help='Show the detection results') - parser.add_argument('--out-dir', + parser.add_argument('-w', + '--weights', type=str, - default='./output', - help='Path to output file') + required=True, + help='PyTorch yolov8 weights') + parser.add_argument('--iou-thres', + type=float, + default=0.65, + help='IOU threshoud for NMS plugin') + parser.add_argument('--conf-thres', + type=float, + default=0.25, + help='CONF threshoud for NMS plugin') + parser.add_argument('--topk', + type=int, + default=100, + help='Max number of detection bboxes') + parser.add_argument('--opset', + type=int, + default=11, + help='ONNX opset version') + parser.add_argument('--sim', + action='store_true', + help='simplify onnx model') + parser.add_argument('--input-shape', + nargs='+', + type=int, + default=[1, 3, 640, 640], + help='Model input shape only for api builder') parser.add_argument('--device', type=str, - default='cuda:0', - help='TensorRT infer device') + default='cpu', + help='Export ONNX device') + parser.add_argument('--dynamic', + action='store_true', + help='Model input shape will dynamically' + ) args = parser.parse_args() + assert len(args.input_shape) == 4 + PostDetect.conf_thres = args.conf_thres + PostDetect.iou_thres = args.iou_thres + PostDetect.topk = args.topk return args +def main(args): + b = args.input_shape[0] + YOLOv8 = YOLO(args.weights) + model = YOLOv8.model.fuse().eval() + for m in model.modules(): + optim(m) + m.to(args.device) + model.to(args.device) + fake_input = torch.randn(args.input_shape).to(args.device) + for _ in range(2): + model(fake_input) + save_path = args.weights.replace('.pt', '.onnx') + with BytesIO() as f: + if args.dynamic: + torch.onnx.export( + model, + fake_input, + f, + opset_version=args.opset, + input_names=['images'], + output_names=['num_dets', 'bboxes', 'scores', 'labels'], + dynamic_axes={'images': {0: 'batch_size'}, + 'num_dets': {0: 'batch_size'}, + 'bboxes': {0: 'batch_size'}, + 'scores': {0: 'batch_size'}, + 'labels': {0: 'batch_size'}, + }) + else: + torch.onnx.export( + model, + fake_input, + f, + opset_version=args.opset, + input_names=['images'], + output_names=['num_dets', 'bboxes', 'scores', 'labels']) + f.seek(0) + onnx_model = onnx.load(f) + onnx.checker.check_model(onnx_model) + shapes = [b, 1, b, args.topk, 4, b, args.topk, b, args.topk] + if args.dynamic is False: + for i in onnx_model.graph.output: + for j in i.type.tensor_type.shape.dim: + j.dim_param = str(shapes.pop(0)) + if args.sim: + try: + onnx_model, check = onnxsim.simplify(onnx_model) + assert check, 'assert check failed' + except Exception as e: + print(f'Simplifier failure: {e}') + onnx.save(onnx_model, save_path) + print(f'ONNX export success, saved as {save_path}') + + if __name__ == '__main__': - args = parse_args() - main(args) + main(parse_args())