import pickle
import warnings
from collections import defaultdict, namedtuple
from pathlib import Path
from typing import List, Optional, Tuple, Union

import onnx
import tensorrt as trt
import torch

warnings.filterwarnings(action='ignore', category=DeprecationWarning)
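

# EngineBuilder builds and serializes a TensorRT engine for a YOLOv8-style
# detector, either by parsing an ONNX export or by re-creating the network
# layer by layer from a pickled checkpoint (build_from_api).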
class EngineBuilder:

    def __init__(
            self,
            checkpoint: Union[str, Path],
            device: Optional[Union[str, int, torch.device]] = None) -> None:
        checkpoint = Path(checkpoint) if isinstance(checkpoint,
                                                    str) else checkpoint
        assert checkpoint.exists() and checkpoint.suffix in ('.onnx', '.pkl')
        self.api = checkpoint.suffix == '.pkl'
        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')
        self.checkpoint = checkpoint
        self.device = device

    def __build_engine(self,
                       fp16: bool = True,
                       input_shape: Union[List, Tuple] = (1, 3, 640, 640),
                       iou_thres: float = 0.65,
                       conf_thres: float = 0.25,
                       topk: int = 100,
                       with_profiling: bool = True) -> None:
        logger = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(logger, namespace='')
        builder = trt.Builder(logger)
        config = builder.create_builder_config()
        config.max_workspace_size = torch.cuda.get_device_properties(
            self.device).total_memory
        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        network = builder.create_network(flag)
        self.logger = logger
        self.builder = builder
        self.network = network
        if self.api:
            self.build_from_api(fp16, input_shape, iou_thres, conf_thres, topk)
        else:
            self.build_from_onnx(iou_thres, conf_thres, topk)
        if fp16 and self.builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
        self.weight = self.checkpoint.with_suffix('.engine')
        if with_profiling:
            config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
        with self.builder.build_engine(self.network, config) as engine:
            self.weight.write_bytes(engine.serialize())
        self.logger.log(
            trt.Logger.WARNING, f'Build TensorRT engine finished.\n'
            f'Saved to {str(self.weight.absolute())}')

    def build(self,
              fp16: bool = True,
              input_shape: Union[List, Tuple] = (1, 3, 640, 640),
              iou_thres: float = 0.65,
              conf_thres: float = 0.25,
              topk: int = 100,
              with_profiling: bool = True) -> None:
        self.__build_engine(fp16, input_shape, iou_thres, conf_thres, topk,
                            with_profiling)

    def build_from_onnx(self,
                        iou_thres: float = 0.65,
                        conf_thres: float = 0.25,
                        topk: int = 100):
        parser = trt.OnnxParser(self.network, self.logger)
        onnx_model = onnx.load(str(self.checkpoint))
        # Patch the detection/NMS attributes of the graph's last node so the
        # engine is built with the requested topk and thresholds.
        onnx_model.graph.node[-1].attribute[2].i = topk
        onnx_model.graph.node[-1].attribute[3].f = conf_thres
        onnx_model.graph.node[-1].attribute[4].f = iou_thres
        if not parser.parse(onnx_model.SerializeToString()):
            raise RuntimeError(
                f'failed to load ONNX file: {str(self.checkpoint)}')
        inputs = [
            self.network.get_input(i) for i in range(self.network.num_inputs)
        ]
        outputs = [
            self.network.get_output(i) for i in range(self.network.num_outputs)
        ]
        for inp in inputs:
            self.logger.log(
                trt.Logger.WARNING,
                f'input "{inp.name}" with shape: {inp.shape} '
                f'dtype: {inp.dtype}')
        for out in outputs:
            self.logger.log(
                trt.Logger.WARNING,
                f'output "{out.name}" with shape: {out.shape} '
                f'dtype: {out.dtype}')

    def build_from_api(
        self,
        fp16: bool = True,
        input_shape: Union[List, Tuple] = (1, 3, 640, 640),
        iou_thres: float = 0.65,
        conf_thres: float = 0.25,
        topk: int = 100,
    ):
        from .api import SPPF, C2f, Conv, Detect, get_depth, get_width
        with open(self.checkpoint, 'rb') as f:
            state_dict = pickle.load(f)
        # Largest channel width for each width multiplier (GW).
        mapping = {0.25: 1024, 0.5: 1024, 0.75: 768, 1.0: 512, 1.25: 512}
        GW = state_dict['GW']
        GD = state_dict['GD']
        width_64 = get_width(64, GW)
        width_128 = get_width(128, GW)
        width_256 = get_width(256, GW)
        width_512 = get_width(512, GW)
        width_1024 = get_width(mapping[GW], GW)
        depth_3 = get_depth(3, GD)
        depth_6 = get_depth(6, GD)
        strides = state_dict['strides']
        reg_max = state_dict['reg_max']
        images = self.network.add_input(name='images',
                                        dtype=trt.float32,
                                        shape=trt.Dims4(input_shape))
        assert images, 'Add input failed'
        # Backbone
        Conv_0 = Conv(self.network, state_dict, images, width_64, 3, 2, 1,
                      'Conv.0')
        Conv_1 = Conv(self.network, state_dict, Conv_0.get_output(0),
                      width_128, 3, 2, 1, 'Conv.1')
        C2f_2 = C2f(self.network, state_dict, Conv_1.get_output(0), width_128,
                    depth_3, True, 1, 0.5, 'C2f.2')
        Conv_3 = Conv(self.network, state_dict, C2f_2.get_output(0), width_256,
                      3, 2, 1, 'Conv.3')
        C2f_4 = C2f(self.network, state_dict, Conv_3.get_output(0), width_256,
                    depth_6, True, 1, 0.5, 'C2f.4')
        Conv_5 = Conv(self.network, state_dict, C2f_4.get_output(0), width_512,
                      3, 2, 1, 'Conv.5')
        C2f_6 = C2f(self.network, state_dict, Conv_5.get_output(0), width_512,
                    depth_6, True, 1, 0.5, 'C2f.6')
        Conv_7 = Conv(self.network, state_dict, C2f_6.get_output(0),
                      width_1024, 3, 2, 1, 'Conv.7')
        C2f_8 = C2f(self.network, state_dict, Conv_7.get_output(0), width_1024,
                    depth_3, True, 1, 0.5, 'C2f.8')
        SPPF_9 = SPPF(self.network, state_dict, C2f_8.get_output(0),
                      width_1024, width_1024, 5, 'SPPF.9')
        # Neck (top-down / bottom-up feature fusion)
        Upsample_10 = self.network.add_resize(SPPF_9.get_output(0))
        assert Upsample_10, 'Add Upsample_10 failed'
        Upsample_10.resize_mode = trt.ResizeMode.NEAREST
        Upsample_10.shape = Upsample_10.get_output(
            0).shape[:2] + C2f_6.get_output(0).shape[2:]
        input_tensors11 = [Upsample_10.get_output(0), C2f_6.get_output(0)]
        Cat_11 = self.network.add_concatenation(input_tensors11)
        C2f_12 = C2f(self.network, state_dict, Cat_11.get_output(0), width_512,
                     depth_3, False, 1, 0.5, 'C2f.12')
        Upsample13 = self.network.add_resize(C2f_12.get_output(0))
        assert Upsample13, 'Add Upsample13 failed'
        Upsample13.resize_mode = trt.ResizeMode.NEAREST
        Upsample13.shape = Upsample13.get_output(
            0).shape[:2] + C2f_4.get_output(0).shape[2:]
        input_tensors14 = [Upsample13.get_output(0), C2f_4.get_output(0)]
        Cat_14 = self.network.add_concatenation(input_tensors14)
        C2f_15 = C2f(self.network, state_dict, Cat_14.get_output(0), width_256,
                     depth_3, False, 1, 0.5, 'C2f.15')
        Conv_16 = Conv(self.network, state_dict, C2f_15.get_output(0),
                       width_256, 3, 2, 1, 'Conv.16')
        input_tensors17 = [Conv_16.get_output(0), C2f_12.get_output(0)]
        Cat_17 = self.network.add_concatenation(input_tensors17)
        C2f_18 = C2f(self.network, state_dict, Cat_17.get_output(0), width_512,
                     depth_3, False, 1, 0.5, 'C2f.18')
        Conv_19 = Conv(self.network, state_dict, C2f_18.get_output(0),
                       width_512, 3, 2, 1, 'Conv.19')
        input_tensors20 = [Conv_19.get_output(0), SPPF_9.get_output(0)]
        Cat_20 = self.network.add_concatenation(input_tensors20)
        C2f_21 = C2f(self.network, state_dict, Cat_20.get_output(0),
                     width_1024, depth_3, False, 1, 0.5, 'C2f.21')
        # Detection head with batched NMS
        input_tensors22 = [
            C2f_15.get_output(0),
            C2f_18.get_output(0),
            C2f_21.get_output(0)
        ]
        batched_nms = Detect(self.network, state_dict, input_tensors22,
                             strides, 'Detect.22', reg_max, fp16, iou_thres,
                             conf_thres, topk)
        for o in range(batched_nms.num_outputs):
            self.network.mark_output(batched_nms.get_output(o))
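

# Minimal usage sketch for EngineBuilder (the checkpoint name is illustrative;
# it assumes a YOLOv8 ONNX export whose last graph node carries the NMS
# attributes patched in build_from_onnx):
#
#     builder = EngineBuilder('yolov8s.onnx', device=0)
#     builder.build(fp16=True, input_shape=(1, 3, 640, 640),
#                   iou_thres=0.65, conf_thres=0.25, topk=100)
#     # The serialized engine is written next to the checkpoint,
#     # e.g. 'yolov8s.engine'.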


class TRTModule(torch.nn.Module):
    dtypeMapping = {
        trt.bool: torch.bool,
        trt.int8: torch.int8,
        trt.int32: torch.int32,
        trt.float16: torch.float16,
        trt.float32: torch.float32
    }

    def __init__(self, weight: Union[str, Path],
                 device: Optional[torch.device]) -> None:
        super(TRTModule, self).__init__()
        self.weight = Path(weight) if isinstance(weight, str) else weight
        self.device = device if device is not None else torch.device('cuda:0')
        # Create the stream on the resolved device so it matches the
        # 'cuda:0' fallback when no device is given.
        self.stream = torch.cuda.Stream(device=self.device)
        self.__init_engine()
        self.__init_bindings()

    def __init_engine(self) -> None:
        logger = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(logger, namespace='')
        with trt.Runtime(logger) as runtime:
            model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
        context = model.create_execution_context()
        names = [model.get_binding_name(i) for i in range(model.num_bindings)]
        self.num_bindings = model.num_bindings
        self.bindings: List[int] = [0] * self.num_bindings
        num_inputs, num_outputs = 0, 0
        for i in range(model.num_bindings):
            if model.binding_is_input(i):
                num_inputs += 1
            else:
                num_outputs += 1
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.model = model
        self.context = context
        self.input_names = names[:num_inputs]
        self.output_names = names[num_inputs:]

    def __init_bindings(self) -> None:
        dynamic = False
        Tensor = namedtuple('Tensor', ('name', 'dtype', 'shape'))
        inp_info = []
        out_info = []
        for i, name in enumerate(self.input_names):
            assert self.model.get_binding_name(i) == name
            dtype = self.dtypeMapping[self.model.get_binding_dtype(i)]
            shape = tuple(self.model.get_binding_shape(i))
            if -1 in shape:
                dynamic = True
            inp_info.append(Tensor(name, dtype, shape))
        for i, name in enumerate(self.output_names):
            i += self.num_inputs
            assert self.model.get_binding_name(i) == name
            dtype = self.dtypeMapping[self.model.get_binding_dtype(i)]
            shape = tuple(self.model.get_binding_shape(i))
            out_info.append(Tensor(name, dtype, shape))
        if not dynamic:
            self.output_tensor = [
                torch.empty(info.shape, dtype=info.dtype, device=self.device)
                for info in out_info
            ]
        self.is_dynamic = dynamic
        self.inp_info = inp_info
        self.out_info = out_info

    def set_profiler(self, profiler: Optional[trt.IProfiler]):
        self.context.profiler = profiler \
            if profiler is not None else trt.Profiler()

    def forward(self, *inputs) -> Union[Tuple, torch.Tensor]:
        assert len(inputs) == self.num_inputs
        contiguous_inputs: List[torch.Tensor] = [
            i.contiguous() for i in inputs
        ]
        for i in range(self.num_inputs):
            self.bindings[i] = contiguous_inputs[i].data_ptr()
            if self.is_dynamic:
                self.context.set_binding_shape(
                    i, tuple(contiguous_inputs[i].shape))
        outputs: List[torch.Tensor] = []
        for i in range(self.num_outputs):
            j = i + self.num_inputs
            if self.is_dynamic:
                shape = tuple(self.context.get_binding_shape(j))
                output = torch.empty(size=shape,
                                     dtype=self.out_info[i].dtype,
                                     device=self.device)
            else:
                output = self.output_tensor[i]
            self.bindings[j] = output.data_ptr()
            outputs.append(output)
        self.context.execute_async_v2(self.bindings, self.stream.cuda_stream)
        self.stream.synchronize()
        return tuple(outputs) if len(outputs) > 1 else outputs[0]
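

# Minimal inference sketch for TRTModule (assumes an engine file built by
# EngineBuilder above and a preprocessed BCHW float tensor already on the GPU;
# the file name and shape are illustrative, and the number/layout of outputs
# depends on the engine's detection head):
#
#     model = TRTModule('yolov8s.engine', torch.device('cuda:0'))
#     im = torch.rand(1, 3, 640, 640, device='cuda:0')
#     outputs = model(im)  # a single tensor or a tuple of tensors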


class TRTProfilerV1(trt.IProfiler):

    def __init__(self):
        trt.IProfiler.__init__(self)
        self.total_runtime = 0.0
        self.recorder = defaultdict(float)

    def report_layer_time(self, layer_name: str, ms: float):
        self.total_runtime += ms * 1000
        self.recorder[layer_name] += ms * 1000

    def report(self):
        f = '\t%40s\t\t\t\t%10.4f'
        print('\t%40s\t\t\t\t%10s' % ('layername', 'cost(us)'))
        for name, cost in sorted(self.recorder.items(), key=lambda x: -x[1]):
            print(
                f %
                (name if len(name) < 40 else name[:35] + ' ' + '*' * 4, cost))
        print(f'\nTotal Inference Time: {self.total_runtime:.4f}(us)')


class TRTProfilerV0(trt.IProfiler):

    def __init__(self):
        trt.IProfiler.__init__(self)

    def report_layer_time(self, layer_name: str, ms: float):
        f = '\t%40s\t\t\t\t%10.4fms'
        print(f % (layer_name if len(layer_name) < 40 else layer_name[:35] +
                   ' ' + '*' * 4, ms))
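

# Profiler usage sketch (attach a profiler before a forward pass; TRTProfilerV1
# aggregates per-layer cost in microseconds and reports on demand, while
# TRTProfilerV0 prints per-layer milliseconds as each layer is timed):
#
#     profiler = TRTProfilerV1()
#     model.set_profiler(profiler)
#     model(im)           # layer timings are recorded during this call
#     profiler.report()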