Support TensorRT api build

pull/1/head
triple-Mu 2 years ago
parent 3bc0c4142a
commit 5574a29c69
  1. 25
      .pre-commit-config.yaml
  2. 101
      README.md
  3. 46
      build.py
  4. 93
      infer.py
  5. 2
      models/__init__.py
  6. 326
      models/api.py
  7. 240
      models/engine.py
  8. 1
      requirements.txt

@ -0,0 +1,25 @@
repos:
- repo: https://github.com/PyCQA/flake8
rev: 5.0.4
hooks:
- id: flake8
- repo: https://github.com/PyCQA/isort
rev: 5.10.1
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.32.0
hooks:
- id: yapf
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: requirements-txt-fixer
- id: double-quote-string-fixer
- id: check-merge-conflict
- id: fix-encoding-pragma
args: ["--remove"]
- id: mixed-line-ending
args: ["--fix=lf"]

@ -1,10 +1,38 @@
# YOLOv8-TensorRT
YOLOv8 using TensorRT accelerate !
`YOLOv8` using TensorRT accelerate !
# Preprocessed ONNX model
# Prepare the environment
You can dowload the onnx model which is pretrained by https://github.com/ultralytics .
1. Install TensorRT follow [`TensorRT offical website`](https://developer.nvidia.com/nvidia-tensorrt-8x-download)
2. Install python requirement.
``` shell
pip install -r requirement.txt
```
3. (optional) Install `ultralytics YOLOv8` package for TensorRT API building.
``` shell
pip install -i https://test.pypi.org/simple/ ultralytics
```
You can download pretrained pytorch model by:
``` shell
wget https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt
wget https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s.pt
wget https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m.pt
wget https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l.pt
wget https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x.pt
```
# Build TensorRT engine by ONNX
## Preprocessed ONNX model
You can dowload the onnx model which are exported by `YOLOv8` package and modified by me .
[**YOLOv8-n**](https://triplemu.oss-cn-beijing.aliyuncs.com/YOLOv8/ONNX/yolov8n_nms.onnx?OSSAccessKeyId=LTAI5tN1dgmZD4PF8AJUXp3J&Expires=1772936700&Signature=r6HgJTTcCSAxQxD9bKO9qBTtigQ%3D)
@ -16,23 +44,34 @@ You can dowload the onnx model which is pretrained by https://github.com/ultraly
[**YOLOv8-x**](https://triplemu.oss-cn-beijing.aliyuncs.com/YOLOv8/ONNX/yolov8x_nms.onnx?OSSAccessKeyId=LTAI5tN1dgmZD4PF8AJUXp3J&Expires=1673936778&Signature=3o%2F7QKhiZg1dW3I6sDrY4ug6MQU%3D)
# Build TensorRT engine by ONNX
## 1. By TensorRT Python api
## 1. By TensorRT ONNX Python api
You can export TensorRT engine by [`build.py` ](build.py).
You can export TensorRT engine from ONNX by [`build.py` ](build.py).
Usage:
``` shell
python3 build.py --onnx yolov8s_nms.onnx --device cuda:0 --fp16
python build.py \
--weights yolov8s_nms.onnx \
--iou-thres 0.65 \
--conf-thres 0.25 \
--topk 100 \
--fp16 \
--device cuda:0
```
#### Description of all arguments
- `--onnx` : The ONNX model you download.
- `--weights` : The ONNX model you download.
- `--iou-thres` : IOU threshold for NMS plugin.
- `--conf-thres` : Confidence threshold for NMS plugin.
- `--topk` : Max number of detection bboxes.
- `--fp16` : Whether to export half-precision engine.
- `--device` : The CUDA deivce you export engine .
- `--half` : Whether to export half-precision model.
You can modify `iou-thres` `conf-thres` `topk` by yourself.
## 2. By trtexec tools
@ -44,13 +83,37 @@ Usage:
/usr/src/tensorrt/bin/trtexec --onnx=yolov8s_nms.onnx --saveEngine=yolov8s_nms.engine --fp16
```
***If you installed TensorRT by a debian package, then the installation path of `trtexec`
is `/usr/src/tensorrt/bin/trtexec`***
**If you installed TensorRT by a debian package, then the installation path of `trtexec`
is `/usr/src/tensorrt/bin/trtexec`**
**If you installed TensorRT by a tar package, then the installation path of `trtexec` is under the `bin` folder in the path you decompressed**
# Build TensorRT engine by API
When you want to build engine by api. You should generate the pickle weights parameters first.
``` shell
python gen_pkl.py -w yolov8s.pt -o yolov8s.pkl
```
You will get a `yolov8s.pkl` which contain the operators' parameters. And you can rebuild `yolov8s` model in TensorRT api.
***If you installed TensorRT by a tar package, then the installation path of `trtexec` is under the `bin` folder in the
path you decompressed***
```
python build.py \
--weights yolov8s.pkl \
--iou-thres 0.65 \
--conf-thres 0.25 \
--topk 100 \
--fp16 \
--input-shape 1 3 640 640 \
--device cuda:0
```
# Infer images by the engine which you export
***Notice !!!*** Now we only support static input shape model build by TensorRT api. You'd best give the legal`input-shape`.
# Infer images by the engine which you export or build
You can infer images with the engine by [`infer.py`](infer.py) .
@ -63,17 +126,14 @@ python3 infer.py --engine yolov8s_nms.engine --imgs data --show --out-dir output
#### Description of all arguments
- `--engine` : The Engine you export.
- `--imgs` : The images path you want to detect.
- `--show` : Whether to show detection results.
- `--out-dir` : Where to save detection results images. It will not work when use `--show` flag.
- `--device` : The CUDA deivce you use.
- `--profile` : Profile the TensorRT engine.
# Profile you engine
If you want to profile the TensorRT engine:
Usage:
@ -81,4 +141,3 @@ Usage:
``` shell
python3 infer.py --engine yolov8s_nms.engine --profile
```

@ -1,21 +1,53 @@
import argparse
import os
from models import EngineBuilder
os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--onnx', help='ONNX file')
parser.add_argument(
'--device', type=str, default='cuda:0', help='TensorRT builder device')
parser.add_argument(
'--fp16', action='store_true', help='Build model with fp16 mode')
parser.add_argument('--weights',
type=str,
required=True,
help='Weights file')
parser.add_argument('--iou-thres',
type=float,
default=0.65,
help='IOU threshoud for NMS plugin')
parser.add_argument('--conf-thres',
type=float,
default=0.25,
help='CONF threshoud for NMS plugin')
parser.add_argument('--topk',
type=int,
default=100,
help='Max number of detection bboxes')
parser.add_argument('--input-shape',
nargs='+',
type=int,
default=[1, 3, 640, 640],
help='Model input shape only for api builder')
parser.add_argument('--fp16',
action='store_true',
help='Build model with fp16 mode')
parser.add_argument('--device',
type=str,
default='cuda:0',
help='TensorRT builder device')
args = parser.parse_args()
assert len(args.input_shape) == 4
return args
def main(args):
builder = EngineBuilder(args.onnx, args.device)
builder.build(fp16=args.fp16)
builder = EngineBuilder(args.weights, args.device)
builder.build(fp16=args.fp16,
input_shape=args.input_shape,
iou_thres=args.iou_thres,
conf_thres=args.conf_thres,
topk=args.topk)
if __name__ == '__main__':

@ -1,25 +1,39 @@
from models import TRTModule, TRTProfilerV0
import argparse
import os
import random
from pathlib import Path
import cv2
import argparse
import numpy as np
import torch
import random
random.seed(0)
from models import TRTModule, TRTProfilerV0
SUFFIXS = ('.bmp', '.dng', '.jpeg', '.jpg', '.mpo', '.png', '.tif', '.tiff', '.webp', '.pfm')
CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush')
os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
random.seed(0)
COLORS = {cls: [random.randint(0, 255) for _ in range(3)] for i, cls in enumerate(CLASSES)}
SUFFIXS = ('.bmp', '.dng', '.jpeg', '.jpg', '.mpo', '.png', '.tif', '.tiff',
'.webp', '.pfm')
CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
'scissors', 'teddy bear', 'hair drier', 'toothbrush')
COLORS = {
cls: [random.randint(0, 255) for _ in range(3)]
for i, cls in enumerate(CLASSES)
}
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
@ -33,7 +47,8 @@ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
# Compute padding
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[
1] # wh padding
dw /= 2 # divide padding into 2 sides
dh /= 2
@ -42,8 +57,16 @@ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
return im, np.array([r, r, r, r], dtype=np.float32), np.array([dw, dh, dw, dh], dtype=np.float32)
im = cv2.copyMakeBorder(im,
top,
bottom,
left,
right,
cv2.BORDER_CONSTANT,
value=color) # add border
return im, np.array([r, r, r, r],
dtype=np.float32), np.array([dw, dh, dw, dh],
dtype=np.float32)
def blob(im):
@ -62,7 +85,9 @@ def main(args):
save_path = Path(args.out_dir)
if images_path.is_dir():
images = [i.absolute() for i in images_path.iterdir() if i.suffix in SUFFIXS]
images = [
i.absolute() for i in images_path.iterdir() if i.suffix in SUFFIXS
]
else:
assert images_path.suffix in SUFFIXS
images = [images_path.absolute()]
@ -92,9 +117,11 @@ def main(args):
cls = CLASSES[cls_id]
color = COLORS[cls]
cv2.rectangle(draw, bbox[:2], bbox[2:], color, 2)
cv2.putText(draw, f'{cls}:{score:.3f}', (bbox[0], bbox[1] - 2),
cv2.FONT_HERSHEY_SIMPLEX, 0.75,
[225, 255, 255], thickness=2)
cv2.putText(draw,
f'{cls}:{score:.3f}', (bbox[0], bbox[1] - 2),
cv2.FONT_HERSHEY_SIMPLEX,
0.75, [225, 255, 255],
thickness=2)
if args.show:
cv2.imshow('result', draw)
cv2.waitKey(0)
@ -106,14 +133,20 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--engine', type=str, help='Engine file')
parser.add_argument('--imgs', type=str, help='Images file')
parser.add_argument(
'--show', action='store_true', help='Show the detection results')
parser.add_argument(
'--out-dir', type=str, default='./output', help='Path to output file')
parser.add_argument(
'--device', type=str, default='cuda:0', help='TensorRT infer device')
parser.add_argument(
'--profile', action='store_true', help='Profile TensorRT engine')
parser.add_argument('--show',
action='store_true',
help='Show the detection results')
parser.add_argument('--out-dir',
type=str,
default='./output',
help='Path to output file')
parser.add_argument('--device',
type=str,
default='cuda:0',
help='TensorRT infer device')
parser.add_argument('--profile',
action='store_true',
help='Profile TensorRT engine')
args = parser.parse_args()
return args

@ -1,3 +1,3 @@
from .engine import EngineBuilder, TRTModule, TRTProfilerV0, TRTProfilerV1
__all__ = ['EngineBuilder', 'TRTModule', 'TRTProfilerV0', 'TRTProfilerV1']
__all__ = ['EngineBuilder', 'TRTModule', 'TRTProfilerV0', 'TRTProfilerV1']

@ -0,0 +1,326 @@
import warnings
from typing import List, OrderedDict, Tuple, Union
import numpy as np
import tensorrt as trt
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
def trtweight(weights: np.ndarray) -> trt.Weights:
weights = weights.astype(weights.dtype.name)
return trt.Weights(weights)
def get_width(x: int, gw: float, divisor: int = 8) -> int:
return int(np.ceil(x * gw / divisor) * divisor)
def get_depth(x: int, gd: float) -> int:
return max(int(round(x * gd)), 1)
def Conv2d(network: trt.INetworkDefinition, weights: OrderedDict,
input: trt.ITensor, out_channel: int, ksize: int, stride: int,
group: int, layer_name: str) -> trt.ILayer:
padding = ksize // 2
conv_w = trtweight(weights[layer_name + '.weight'])
conv_b = trtweight(weights[layer_name + '.bias'])
conv = network.add_convolution_nd(input,
num_output_maps=out_channel,
kernel_shape=trt.DimsHW(ksize, ksize),
kernel=conv_w,
bias=conv_b)
assert conv, 'Add convolution_nd layer failed'
conv.stride_nd = trt.DimsHW(stride, stride)
conv.padding_nd = trt.DimsHW(padding, padding)
conv.num_groups = group
return conv
def Conv(network: trt.INetworkDefinition, weights: OrderedDict,
input: trt.ITensor, out_channel: int, ksize: int, stride: int,
group: int, layer_name: str) -> trt.ILayer:
padding = ksize // 2
if ksize > 3:
padding -= 1
conv_w = trtweight(weights[layer_name + '.conv.weight'])
conv_b = trtweight(weights[layer_name + '.conv.bias'])
conv = network.add_convolution_nd(input,
num_output_maps=out_channel,
kernel_shape=trt.DimsHW(ksize, ksize),
kernel=conv_w,
bias=conv_b)
assert conv, 'Add convolution_nd layer failed'
conv.stride_nd = trt.DimsHW(stride, stride)
conv.padding_nd = trt.DimsHW(padding, padding)
conv.num_groups = group
sigmoid = network.add_activation(conv.get_output(0),
trt.ActivationType.SIGMOID)
assert sigmoid, 'Add activation layer failed'
dot_product = network.add_elementwise(conv.get_output(0),
sigmoid.get_output(0),
trt.ElementWiseOperation.PROD)
assert dot_product, 'Add elementwise layer failed'
return dot_product
def Bottleneck(network: trt.INetworkDefinition, weights: OrderedDict,
input: trt.ITensor, c1: int, c2: int, shortcut: bool,
group: int, scale: float, layer_name: str) -> trt.ILayer:
c_ = int(c2 * scale)
conv1 = Conv(network, weights, input, c_, 3, 1, 1, layer_name + '.cv1')
conv2 = Conv(network, weights, conv1.get_output(0), c2, 3, 1, group,
layer_name + '.cv2')
if shortcut and c1 == c2:
ew = network.add_elementwise(input,
conv2.get_output(0),
op=trt.ElementWiseOperation.SUM)
assert ew, 'Add elementwise layer failed'
return ew
return conv2
def C2f(network: trt.INetworkDefinition, weights: OrderedDict,
input: trt.ITensor, cout: int, n: int, shortcut: bool, group: int,
scale: float, layer_name: str) -> trt.ILayer:
c_ = int(cout * scale) # e:expand param
conv1 = Conv(network, weights, input, 2 * c_, 1, 1, 1, layer_name + '.cv1')
y1 = conv1.get_output(0)
b, _, h, w = y1.shape
slice = network.add_slice(y1, (0, c_, 0, 0), (b, c_, h, w), (1, 1, 1, 1))
assert slice, 'Add slice layer failed'
y2 = slice.get_output(0)
input_tensors = [y1]
for i in range(n):
b = Bottleneck(network, weights, y2, c_, c_, shortcut, group, 1.0,
layer_name + '.m.' + str(i))
y2 = b.get_output(0)
input_tensors.append(y2)
cat = network.add_concatenation(input_tensors)
assert cat, 'Add concatenation layer failed'
conv2 = Conv(network, weights, cat.get_output(0), cout, 1, 1, 1,
layer_name + '.cv2')
return conv2
def SPPF(network: trt.INetworkDefinition, weights: OrderedDict,
input: trt.ITensor, c1: int, c2: int, ksize: int,
layer_name: str) -> trt.ILayer:
c_ = c1 // 2
conv1 = Conv(network, weights, input, c_, 1, 1, 1, layer_name + '.cv1')
pool1 = network.add_pooling_nd(conv1.get_output(0), trt.PoolingType.MAX,
trt.DimsHW(ksize, ksize))
assert pool1, 'Add pooling_nd layer failed'
pool1.padding_nd = trt.DimsHW(ksize // 2, ksize // 2)
pool1.stride_nd = trt.DimsHW(1, 1)
pool2 = network.add_pooling_nd(pool1.get_output(0), trt.PoolingType.MAX,
trt.DimsHW(ksize, ksize))
assert pool2, 'Add pooling_nd layer failed'
pool2.padding_nd = trt.DimsHW(ksize // 2, ksize // 2)
pool2.stride_nd = trt.DimsHW(1, 1)
pool3 = network.add_pooling_nd(pool2.get_output(0), trt.PoolingType.MAX,
trt.DimsHW(ksize, ksize))
assert pool3, 'Add pooling_nd layer failed'
pool3.padding_nd = trt.DimsHW(ksize // 2, ksize // 2)
pool3.stride_nd = trt.DimsHW(1, 1)
input_tensors = [
conv1.get_output(0),
pool1.get_output(0),
pool2.get_output(0),
pool3.get_output(0)
]
cat = network.add_concatenation(input_tensors)
assert cat, 'Add concatenation layer failed'
conv2 = Conv(network, weights, cat.get_output(0), c2, 1, 1, 1,
layer_name + '.cv2')
return conv2
def Detect(
network: trt.INetworkDefinition,
weights: OrderedDict,
input: Union[List, Tuple],
s: Union[List, Tuple],
layer_name: str,
reg_max: int = 16,
fp16: bool = True,
iou: float = 0.65,
conf: float = 0.25,
topk: int = 100,
) -> trt.ILayer:
bboxes_branch = []
scores_branch = []
anchors = []
strides = []
for i, (inp, stride) in enumerate(zip(input, s)):
h, w = inp.shape[2:]
sx = np.arange(0, w).astype(np.float16 if fp16 else np.float32) + 0.5
sy = np.arange(0, h).astype(np.float16 if fp16 else np.float32) + 0.5
sy, sx = np.meshgrid(sy, sx)
a = np.ascontiguousarray(np.stack((sy, sx), -1).reshape(-1, 2))
anchors.append(a)
strides.append(
np.full((1, h * w),
stride,
dtype=np.float16 if fp16 else np.float32))
c2 = weights[f'{layer_name}.cv2.{i}.0.conv.weight'].shape[0]
c3 = weights[f'{layer_name}.cv3.{i}.0.conv.weight'].shape[0]
nc = weights[f'{layer_name}.cv3.0.2.weight'].shape[0]
reg_max_x4 = weights[layer_name + f'.cv2.{i}.2.weight'].shape[0]
assert reg_max_x4 == reg_max * 4
b_Conv_0 = Conv(network, weights, inp, c2, 3, 1, 1,
layer_name + f'.cv2.{i}.0')
b_Conv_1 = Conv(network, weights, b_Conv_0.get_output(0), c2, 3, 1, 1,
layer_name + f'.cv2.{i}.1')
b_Conv_2 = Conv2d(network, weights, b_Conv_1.get_output(0), reg_max_x4,
1, 1, 1, layer_name + f'.cv2.{i}.2')
b_out = b_Conv_2.get_output(0)
b_shape = network.add_constant([
4,
], np.array(b_out.shape[0:1] + (4, reg_max, -1), dtype=np.int32))
assert b_shape, 'Add constant layer failed'
b_shuffle = network.add_shuffle(b_out)
assert b_shuffle, 'Add shuffle layer failed'
b_shuffle.set_input(1, b_shape.get_output(0))
b_shuffle.second_transpose = (0, 3, 1, 2)
bboxes_branch.append(b_shuffle.get_output(0))
s_Conv_0 = Conv(network, weights, inp, c3, 3, 1, 1,
layer_name + f'.cv3.{i}.0')
s_Conv_1 = Conv(network, weights, s_Conv_0.get_output(0), c3, 3, 1, 1,
layer_name + f'.cv3.{i}.1')
s_Conv_2 = Conv2d(network, weights, s_Conv_1.get_output(0), nc, 1, 1,
1, layer_name + f'.cv3.{i}.2')
s_out = s_Conv_2.get_output(0)
s_shape = network.add_constant([
3,
], np.array(s_out.shape[0:2] + (-1, ), dtype=np.int32))
assert s_shape, 'Add constant layer failed'
s_shuffle = network.add_shuffle(s_out)
assert s_shuffle, 'Add shuffle layer failed'
s_shuffle.set_input(1, s_shape.get_output(0))
s_shuffle.second_transpose = (0, 2, 1)
scores_branch.append(s_shuffle.get_output(0))
Cat_bboxes = network.add_concatenation(bboxes_branch)
assert Cat_bboxes, 'Add concatenation layer failed'
Cat_scores = network.add_concatenation(scores_branch)
assert Cat_scores, 'Add concatenation layer failed'
Cat_scores.axis = 1
Softmax = network.add_softmax(Cat_bboxes.get_output(0))
assert Softmax, 'Add softmax layer failed'
Softmax.axes = 1 << 3
SCORES = network.add_activation(Cat_scores.get_output(0),
trt.ActivationType.SIGMOID)
assert SCORES, 'Add activation layer failed'
reg_max = np.arange(
0, reg_max).astype(np.float16 if fp16 else np.float32).reshape(
(1, 1, -1, 1))
constant = network.add_constant(reg_max.shape, reg_max)
assert constant, 'Add constant layer failed'
Matmul = network.add_matrix_multiply(Softmax.get_output(0),
trt.MatrixOperation.NONE,
constant.get_output(0),
trt.MatrixOperation.NONE)
assert Matmul, 'Add matrix_multiply layer failed'
pre_bboxes = network.add_gather(
Matmul.get_output(0),
network.add_constant([
1,
], np.array([0], dtype=np.int32)).get_output(0), 3)
assert pre_bboxes, 'Add gather layer failed'
pre_bboxes.num_elementwise_dims = 1
pre_bboxes_tensor = pre_bboxes.get_output(0)
b, c, _ = pre_bboxes_tensor.shape
slice_x1y1 = network.add_slice(pre_bboxes_tensor, (0, 0, 0), (b, c, 2),
(1, 1, 1))
assert slice_x1y1, 'Add slice layer failed'
slice_x2y2 = network.add_slice(pre_bboxes_tensor, (0, 0, 2), (b, c, 2),
(1, 1, 1))
assert slice_x2y2, 'Add slice layer failed'
anchors = np.concatenate(anchors, 0)[np.newaxis]
anchors = network.add_constant(anchors.shape, anchors)
assert anchors, 'Add constant layer failed'
strides = np.concatenate(strides, 1)[..., np.newaxis]
strides = network.add_constant(strides.shape, strides)
assert strides, 'Add constant layer failed'
Sub = network.add_elementwise(anchors.get_output(0),
slice_x1y1.get_output(0),
trt.ElementWiseOperation.SUB)
assert Sub, 'Add elementwise layer failed'
Add = network.add_elementwise(anchors.get_output(0),
slice_x2y2.get_output(0),
trt.ElementWiseOperation.SUM)
assert Add, 'Add elementwise layer failed'
x1y1 = Sub.get_output(0)
x2y2 = Add.get_output(0)
Cat_bboxes_ = network.add_concatenation([x1y1, x2y2])
assert Cat_bboxes_, 'Add concatenation layer failed'
Cat_bboxes_.axis = 2
BBOXES = network.add_elementwise(Cat_bboxes_.get_output(0),
strides.get_output(0),
trt.ElementWiseOperation.PROD)
assert BBOXES, 'Add elementwise layer failed'
plugin_creator = trt.get_plugin_registry().get_plugin_creator(
'EfficientNMS_TRT', '1')
assert plugin_creator, 'Plugin EfficientNMS_TRT is not registried'
background_class = trt.PluginField('background_class',
np.array(-1, np.int32),
trt.PluginFieldType.INT32)
box_coding = trt.PluginField('box_coding', np.array(0, np.int32),
trt.PluginFieldType.INT32)
iou_threshold = trt.PluginField('iou_threshold',
np.array(iou, dtype=np.float32),
trt.PluginFieldType.FLOAT32)
max_output_boxes = trt.PluginField('max_output_boxes',
np.array(topk, np.int32),
trt.PluginFieldType.INT32)
plugin_version = trt.PluginField('plugin_version', np.array('1'),
trt.PluginFieldType.CHAR)
score_activation = trt.PluginField('score_activation',
np.array(0, np.int32),
trt.PluginFieldType.INT32)
score_threshold = trt.PluginField('score_threshold',
np.array(conf, dtype=np.float32),
trt.PluginFieldType.FLOAT32)
batched_nms_op = plugin_creator.create_plugin(
name='batched_nms',
field_collection=trt.PluginFieldCollection([
background_class, box_coding, iou_threshold, max_output_boxes,
plugin_version, score_activation, score_threshold
]))
batched_nms = network.add_plugin_v2(
inputs=[BBOXES.get_output(0),
SCORES.get_output(0)],
plugin=batched_nms_op)
batched_nms.get_output(0).name = 'num_dets'
batched_nms.get_output(1).name = 'bboxes'
batched_nms.get_output(2).name = 'scores'
batched_nms.get_output(3).name = 'labels'
return batched_nms

@ -1,12 +1,11 @@
import pickle
import warnings
from collections import defaultdict, namedtuple
from pathlib import Path
from typing import Optional, Union, List, Tuple
from collections import namedtuple, defaultdict
from typing import List, Optional, Tuple, Union
try:
import tensorrt as trt
except Exception:
trt = None
import warnings
import onnx
import tensorrt as trt
import torch
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
@ -14,9 +13,14 @@ warnings.filterwarnings(action='ignore', category=DeprecationWarning)
class EngineBuilder:
def __init__(self, checkpoint: Union[str, Path], device: Optional[Union[str, int, torch.device]] = None) -> None:
checkpoint = Path(checkpoint) if isinstance(checkpoint, str) else checkpoint
assert checkpoint.exists() and checkpoint.suffix == '.onnx'
def __init__(
self,
checkpoint: Union[str, Path],
device: Optional[Union[str, int, torch.device]] = None) -> None:
checkpoint = Path(checkpoint) if isinstance(checkpoint,
str) else checkpoint
assert checkpoint.exists() and checkpoint.suffix in ('.onnx', '.pkl')
self.api = checkpoint.suffix == '.pkl'
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
@ -25,46 +29,185 @@ class EngineBuilder:
self.checkpoint = checkpoint
self.device = device
def __build_engine(self, fp16: bool = True, with_profiling: bool = True) -> None:
def __build_engine(self,
fp16: bool = True,
input_shape: Union[List, Tuple] = (1, 3, 640, 640),
iou_thres: float = 0.65,
conf_thres: float = 0.25,
topk: int = 100,
with_profiling: bool = True) -> None:
logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, namespace='')
builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = torch.cuda.get_device_properties(self.device).total_memory
config.max_workspace_size = torch.cuda.get_device_properties(
self.device).total_memory
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(str(self.checkpoint)):
raise RuntimeError(f'failed to load ONNX file: {str(self.checkpoint)}')
inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
for inp in inputs:
logger.log(trt.Logger.WARNING, f'input "{inp.name}" with shape: {inp.shape} dtype: {inp.dtype}')
for out in outputs:
logger.log(trt.Logger.WARNING, f'output "{out.name}" with shape: {out.shape} dtype: {out.dtype}')
if fp16 and builder.platform_has_fast_fp16:
self.logger = logger
self.builder = builder
self.network = network
if self.api:
self.build_from_api(fp16, input_shape, iou_thres, conf_thres, topk)
else:
self.build_from_onnx(iou_thres, conf_thres, topk)
if fp16 and self.builder.platform_has_fast_fp16:
config.set_flag(trt.BuilderFlag.FP16)
self.weight = self.checkpoint.with_suffix('.engine')
if with_profiling:
config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
with builder.build_engine(network, config) as engine:
with self.builder.build_engine(self.network, config) as engine:
self.weight.write_bytes(engine.serialize())
logger.log(trt.Logger.WARNING, f'Build tensorrt engine finish.\nSave in {str(self.weight.absolute())}')
self.logger.log(
trt.Logger.WARNING, f'Build tensorrt engine finish.\n'
f'Save in {str(self.weight.absolute())}')
def build(self,
fp16: bool = True,
input_shape: Union[List, Tuple] = (1, 3, 640, 640),
iou_thres: float = 0.65,
conf_thres: float = 0.25,
topk: int = 100,
with_profiling=True) -> None:
self.__build_engine(fp16, input_shape, iou_thres, conf_thres, topk,
with_profiling)
def build_from_onnx(self,
iou_thres: float = 0.65,
conf_thres: float = 0.25,
topk: int = 100):
parser = trt.OnnxParser(self.network, self.logger)
onnx_model = onnx.load(str(self.checkpoint))
onnx_model.graph.node[-1].attribute[2].i = topk
onnx_model.graph.node[-1].attribute[3].f = conf_thres
onnx_model.graph.node[-1].attribute[4].f = iou_thres
if not parser.parse(onnx_model.SerializeToString()):
raise RuntimeError(
f'failed to load ONNX file: {str(self.checkpoint)}')
inputs = [
self.network.get_input(i) for i in range(self.network.num_inputs)
]
outputs = [
self.network.get_output(i) for i in range(self.network.num_outputs)
]
def build(self, fp16: bool = True, with_profiling=True) -> None:
self.__build_engine(fp16, with_profiling)
for inp in inputs:
self.logger.log(
trt.Logger.WARNING,
f'input "{inp.name}" with shape: {inp.shape} '
f'dtype: {inp.dtype}')
for out in outputs:
self.logger.log(
trt.Logger.WARNING,
f'output "{out.name}" with shape: {out.shape} '
f'dtype: {out.dtype}')
def build_from_api(
self,
fp16: bool = True,
input_shape: Union[List, Tuple] = (1, 3, 640, 640),
iou_thres: float = 0.65,
conf_thres: float = 0.25,
topk: int = 100,
):
from .api import SPPF, C2f, Conv, Detect, get_depth, get_width
with open(self.checkpoint, 'rb') as f:
state_dict = pickle.load(f)
mapping = {0.25: 1024, 0.5: 1024, 0.75: 768, 1.0: 512, 1.25: 512}
GW = state_dict['GW']
GD = state_dict['GD']
width_64 = get_width(64, GW)
width_128 = get_width(128, GW)
width_256 = get_width(256, GW)
width_512 = get_width(512, GW)
width_1024 = get_width(mapping[GW], GW)
depth_3 = get_depth(3, GD)
depth_6 = get_depth(6, GD)
strides = state_dict['strides']
reg_max = state_dict['reg_max']
images = self.network.add_input(name='images',
dtype=trt.float32,
shape=trt.Dims4(input_shape))
assert images, 'Add input failed'
Conv_0 = Conv(self.network, state_dict, images, width_64, 3, 2, 1,
'Conv.0')
Conv_1 = Conv(self.network, state_dict, Conv_0.get_output(0),
width_128, 3, 2, 1, 'Conv.1')
C2f_2 = C2f(self.network, state_dict, Conv_1.get_output(0), width_128,
depth_3, True, 1, 0.5, 'C2f.2')
Conv_3 = Conv(self.network, state_dict, C2f_2.get_output(0), width_256,
3, 2, 1, 'Conv.3')
C2f_4 = C2f(self.network, state_dict, Conv_3.get_output(0), width_256,
depth_6, True, 1, 0.5, 'C2f.4')
Conv_5 = Conv(self.network, state_dict, C2f_4.get_output(0), width_512,
3, 2, 1, 'Conv.5')
C2f_6 = C2f(self.network, state_dict, Conv_5.get_output(0), width_512,
depth_6, True, 1, 0.5, 'C2f.6')
Conv_7 = Conv(self.network, state_dict, C2f_6.get_output(0),
width_1024, 3, 2, 1, 'Conv.7')
C2f_8 = C2f(self.network, state_dict, Conv_7.get_output(0), width_1024,
depth_3, True, 1, 0.5, 'C2f.8')
SPPF_9 = SPPF(self.network, state_dict, C2f_8.get_output(0),
width_1024, width_1024, 5, 'SPPF.9')
Upsample_10 = self.network.add_resize(SPPF_9.get_output(0))
assert Upsample_10, 'Add Upsample_10 failed'
Upsample_10.resize_mode = trt.ResizeMode.NEAREST
Upsample_10.shape = Upsample_10.get_output(
0).shape[:2] + C2f_6.get_output(0).shape[2:]
input_tensors11 = [Upsample_10.get_output(0), C2f_6.get_output(0)]
Cat_11 = self.network.add_concatenation(input_tensors11)
C2f_12 = C2f(self.network, state_dict, Cat_11.get_output(0), width_512,
depth_3, False, 1, 0.5, 'C2f.12')
Upsample13 = self.network.add_resize(C2f_12.get_output(0))
assert Upsample13, 'Add Upsample13 failed'
Upsample13.resize_mode = trt.ResizeMode.NEAREST
Upsample13.shape = Upsample13.get_output(
0).shape[:2] + C2f_4.get_output(0).shape[2:]
input_tensors14 = [Upsample13.get_output(0), C2f_4.get_output(0)]
Cat_14 = self.network.add_concatenation(input_tensors14)
C2f_15 = C2f(self.network, state_dict, Cat_14.get_output(0), width_256,
depth_3, False, 1, 0.5, 'C2f.15')
Conv_16 = Conv(self.network, state_dict, C2f_15.get_output(0),
width_256, 3, 2, 1, 'Conv.16')
input_tensors17 = [Conv_16.get_output(0), C2f_12.get_output(0)]
Cat_17 = self.network.add_concatenation(input_tensors17)
C2f_18 = C2f(self.network, state_dict, Cat_17.get_output(0), width_512,
depth_3, False, 1, 0.5, 'C2f.18')
Conv_19 = Conv(self.network, state_dict, C2f_18.get_output(0),
width_512, 3, 2, 1, 'Conv.19')
input_tensors20 = [Conv_19.get_output(0), SPPF_9.get_output(0)]
Cat_20 = self.network.add_concatenation(input_tensors20)
C2f_21 = C2f(self.network, state_dict, Cat_20.get_output(0),
width_1024, depth_3, False, 1, 0.5, 'C2f.21')
input_tensors22 = [
C2f_15.get_output(0),
C2f_18.get_output(0),
C2f_21.get_output(0)
]
batched_nms = Detect(self.network, state_dict, input_tensors22,
strides, 'Detect.22', reg_max, fp16, iou_thres,
conf_thres, topk)
for o in range(batched_nms.num_outputs):
self.network.mark_output(batched_nms.get_output(o))
class TRTModule(torch.nn.Module):
dtypeMapping = {trt.bool: torch.bool,
trt.int8: torch.int8,
trt.int32: torch.int32,
trt.float16: torch.float16,
trt.float32: torch.float32}
def __init__(self, weight: Union[str, Path], device: Optional[torch.device]) -> None:
dtypeMapping = {
trt.bool: torch.bool,
trt.int8: torch.int8,
trt.int32: torch.int32,
trt.float16: torch.float16,
trt.float32: torch.float32
}
def __init__(self, weight: Union[str, Path],
device: Optional[torch.device]) -> None:
super(TRTModule, self).__init__()
self.weight = Path(weight) if isinstance(weight, str) else weight
self.device = device if device is not None else torch.device('cuda:0')
@ -107,7 +250,8 @@ class TRTModule(torch.nn.Module):
assert self.model.get_binding_name(i) == name
dtype = self.dtypeMapping[self.model.get_binding_dtype(i)]
shape = tuple(self.model.get_binding_shape(i))
if -1 in shape: dynamic = True
if -1 in shape:
dynamic = True
inp_info.append(Tensor(name, dtype, shape))
for i, name in enumerate(self.output_names):
i += self.num_inputs
@ -117,23 +261,30 @@ class TRTModule(torch.nn.Module):
out_info.append(Tensor(name, dtype, shape))
if not dynamic:
self.output_tensor = [torch.empty(info.shape, dtype=info.dtype, device=self.device) for info in out_info]
self.output_tensor = [
torch.empty(info.shape, dtype=info.dtype, device=self.device)
for info in out_info
]
self.is_dynamic = dynamic
self.inp_info = inp_info
self.out_infp = out_info
def set_profiler(self, profiler: Optional[trt.IProfiler]):
self.context.profiler = profiler if profiler is not None else trt.Profiler()
self.context.profiler = profiler \
if profiler is not None else trt.Profiler()
def forward(self, *inputs) -> Union[Tuple, torch.Tensor]:
assert len(inputs) == self.num_inputs
contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
contiguous_inputs: List[torch.Tensor] = [
i.contiguous() for i in inputs
]
for i in range(self.num_inputs):
self.bindings[i] = contiguous_inputs[i].data_ptr()
if self.is_dynamic:
self.context.set_binding_shape(i, tuple(contiguous_inputs[i].shape))
self.context.set_binding_shape(
i, tuple(contiguous_inputs[i].shape))
outputs: List[torch.Tensor] = []
@ -141,7 +292,9 @@ class TRTModule(torch.nn.Module):
j = i + self.num_inputs
if self.is_dynamic:
shape = tuple(self.context.get_binding_shape(j))
output = torch.empty(size=shape, dtype=self.out_info[i].dtype, device=self.device)
output = torch.empty(size=shape,
dtype=self.out_info[i].dtype,
device=self.device)
else:
output = self.output_tensor[i]
self.bindings[j] = output.data_ptr()
@ -154,6 +307,7 @@ class TRTModule(torch.nn.Module):
class TRTProfilerV1(trt.IProfiler):
def __init__(self):
trt.IProfiler.__init__(self)
self.total_runtime = 0.0
@ -167,14 +321,18 @@ class TRTProfilerV1(trt.IProfiler):
f = '\t%40s\t\t\t\t%10.4f'
print('\t%40s\t\t\t\t%10s' % ('layername', 'cost(us)'))
for name, cost in sorted(self.recorder.items(), key=lambda x: -x[1]):
print(f % (name if len(name) < 40 else name[:35] + ' ' + '*' * 4, cost))
print(
f %
(name if len(name) < 40 else name[:35] + ' ' + '*' * 4, cost))
print(f'\nTotal Inference Time: {self.total_runtime:.4f}(us)')
class TRTProfilerV0(trt.IProfiler):
def __init__(self):
trt.IProfiler.__init__(self)
def report_layer_time(self, layer_name: str, ms: float):
f = '\t%40s\t\t\t\t%10.4fms'
print(f % (layer_name if len(layer_name) < 40 else layer_name[:35] + ' ' + '*' * 4, ms))
print(f % (layer_name if len(layer_name) < 40 else layer_name[:35] +
' ' + '*' * 4, ms))

@ -2,3 +2,4 @@ numpy
opencv-python
torch
# tensorrt
# ultralytics -i https://test.pypi.org/simple

Loading…
Cancel
Save