import pickle
import warnings
from collections import defaultdict, namedtuple
from pathlib import Path
from typing import List, Optional, Tuple, Union

import onnx
import tensorrt as trt
import torch

warnings.filterwarnings(action='ignore', category=DeprecationWarning)


class EngineBuilder:
    """Build a TensorRT engine from a checkpoint and serialize it to disk.

    The checkpoint is either an ONNX model (``.onnx``), which is parsed with
    ``trt.OnnxParser``, or a pickled weight dictionary (``.pkl``), from which
    the network is assembled layer by layer via the helpers in ``.api``.
    The serialized engine is written next to the checkpoint with the
    ``.engine`` suffix.
    """

    def __init__(
            self,
            checkpoint: Union[str, Path],
            device: Optional[Union[str, int, torch.device]] = None) -> None:
        """
        Args:
            checkpoint: path to an existing ``.onnx`` or ``.pkl`` file.
            device: device whose total memory sizes the builder workspace.
                A ``str`` is passed to ``torch.device``; an ``int`` ``n``
                becomes ``cuda:n``; a ``torch.device`` or ``None`` is kept
                as given.
        """
        checkpoint = Path(checkpoint) if isinstance(checkpoint,
                                                    str) else checkpoint
        assert checkpoint.exists() and checkpoint.suffix in ('.onnx', '.pkl')
        # '.pkl' checkpoints go through the layer API, '.onnx' through the
        # ONNX parser.
        self.api = checkpoint.suffix == '.pkl'
        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')

        self.checkpoint = checkpoint
        self.device = device

    def __build_engine(self,
                       fp16: bool = True,
                       input_shape: Union[List, Tuple] = (1, 3, 640, 640),
                       iou_thres: float = 0.65,
                       conf_thres: float = 0.25,
                       topk: int = 100,
                       with_profiling: bool = True) -> None:
        """Create the network, build the engine and write it to disk.

        Raises:
            RuntimeError: if TensorRT fails to build the engine.
        """
        logger = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(logger, namespace='')
        builder = trt.Builder(logger)
        config = builder.create_builder_config()
        # Let the builder use up to the device's total memory as workspace.
        config.max_workspace_size = torch.cuda.get_device_properties(
            self.device).total_memory
        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        network = builder.create_network(flag)

        # Stored on self so build_from_api / build_from_onnx can populate them.
        self.logger = logger
        self.builder = builder
        self.network = network
        if self.api:
            self.build_from_api(fp16, input_shape, iou_thres, conf_thres, topk)
        else:
            self.build_from_onnx(iou_thres, conf_thres, topk)
        if fp16 and self.builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
        self.weight = self.checkpoint.with_suffix('.engine')

        if with_profiling:
            config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
        engine = self.builder.build_engine(self.network, config)
        # build_engine returns None on failure; fail loudly instead of
        # crashing later on an attribute access of None.
        if engine is None:
            raise RuntimeError(
                f'failed to build TensorRT engine from: {str(self.checkpoint)}')
        with engine:
            self.weight.write_bytes(engine.serialize())
        self.logger.log(
            trt.Logger.WARNING, f'Build tensorrt engine finish.\n'
            f'Save in {str(self.weight.absolute())}')

    def build(self,
              fp16: bool = True,
              input_shape: Union[List, Tuple] = (1, 3, 640, 640),
              iou_thres: float = 0.65,
              conf_thres: float = 0.25,
              topk: int = 100,
              with_profiling=True) -> None:
        """Build the engine and save it as ``<checkpoint>.engine``.

        Args:
            fp16: enable FP16 if the platform reports fast FP16 support.
            input_shape: static input shape; used only for ``.pkl`` builds.
            iou_thres: IoU threshold written into the NMS stage.
            conf_thres: score threshold written into the NMS stage.
            topk: maximum number of detections written into the NMS stage.
            with_profiling: set DETAILED profiling verbosity on the engine.
        """
        self.__build_engine(fp16, input_shape, iou_thres, conf_thres, topk,
                            with_profiling)

    def build_from_onnx(self,
                        iou_thres: float = 0.65,
                        conf_thres: float = 0.25,
                        topk: int = 100):
        """Parse the ONNX checkpoint into ``self.network``.

        Before parsing, attributes 2, 3 and 4 of the graph's last node are
        overwritten with ``topk`` (int), ``conf_thres`` (float) and
        ``iou_thres`` (float).
        NOTE(review): this assumes the exporter placed an NMS node last with
        its attributes in exactly that order — confirm against the exporter.

        Raises:
            RuntimeError: if the ONNX parser rejects the model.
        """
        parser = trt.OnnxParser(self.network, self.logger)
        onnx_model = onnx.load(str(self.checkpoint))
        onnx_model.graph.node[-1].attribute[2].i = topk
        onnx_model.graph.node[-1].attribute[3].f = conf_thres
        onnx_model.graph.node[-1].attribute[4].f = iou_thres

        if not parser.parse(onnx_model.SerializeToString()):
            raise RuntimeError(
                f'failed to load ONNX file: {str(self.checkpoint)}')
        inputs = [
            self.network.get_input(i) for i in range(self.network.num_inputs)
        ]
        outputs = [
            self.network.get_output(i)
            for i in range(self.network.num_outputs)
        ]

        for inp in inputs:
            self.logger.log(
                trt.Logger.WARNING,
                f'input "{inp.name}" with shape: {inp.shape} '
                f'dtype: {inp.dtype}')
        for out in outputs:
            self.logger.log(
                trt.Logger.WARNING,
                f'output "{out.name}" with shape: {out.shape} '
                f'dtype: {out.dtype}')

    def build_from_api(
        self,
        fp16: bool = True,
        input_shape: Union[List, Tuple] = (1, 3, 640, 640),
        iou_thres: float = 0.65,
        conf_thres: float = 0.25,
        topk: int = 100,
    ):
        """Assemble the network layer by layer from a pickled weight dict.

        The pickle must provide at least the keys ``'GW'``, ``'GD'``,
        ``'strides'`` and ``'reg_max'`` plus whatever per-layer weights the
        ``.api`` helpers read. The final ``Detect`` layer's outputs are
        marked as network outputs.

        NOTE(review): ``pickle.load`` executes arbitrary code; only load
        checkpoints from trusted sources.
        """
        from .api import SPPF, C2f, Conv, Detect, get_depth, get_width

        with open(self.checkpoint, 'rb') as f:
            state_dict = pickle.load(f)

        # Base channel count of the deepest stage, keyed by GW (presumably
        # the width gain); an unlisted GW value raises KeyError.
        mapping = {0.25: 1024, 0.5: 1024, 0.75: 768, 1.0: 512, 1.25: 512}

        GW = state_dict['GW']
        GD = state_dict['GD']
        width_64 = get_width(64, GW)
        width_128 = get_width(128, GW)
        width_256 = get_width(256, GW)
        width_512 = get_width(512, GW)
        width_1024 = get_width(mapping[GW], GW)
        depth_3 = get_depth(3, GD)
        depth_6 = get_depth(6, GD)
        strides = state_dict['strides']
        reg_max = state_dict['reg_max']
        images = self.network.add_input(name='images',
                                        dtype=trt.float32,
                                        shape=trt.Dims4(input_shape))
        assert images, 'Add input failed'

        # Backbone.
        Conv_0 = Conv(self.network, state_dict, images, width_64, 3, 2, 1,
                      'Conv.0')
        Conv_1 = Conv(self.network, state_dict, Conv_0.get_output(0),
                      width_128, 3, 2, 1, 'Conv.1')
        C2f_2 = C2f(self.network, state_dict, Conv_1.get_output(0), width_128,
                    depth_3, True, 1, 0.5, 'C2f.2')
        Conv_3 = Conv(self.network, state_dict, C2f_2.get_output(0), width_256,
                      3, 2, 1, 'Conv.3')
        C2f_4 = C2f(self.network, state_dict, Conv_3.get_output(0), width_256,
                    depth_6, True, 1, 0.5, 'C2f.4')
        Conv_5 = Conv(self.network, state_dict, C2f_4.get_output(0), width_512,
                      3, 2, 1, 'Conv.5')
        C2f_6 = C2f(self.network, state_dict, Conv_5.get_output(0), width_512,
                    depth_6, True, 1, 0.5, 'C2f.6')
        Conv_7 = Conv(self.network, state_dict, C2f_6.get_output(0),
                      width_1024, 3, 2, 1, 'Conv.7')
        C2f_8 = C2f(self.network, state_dict, Conv_7.get_output(0), width_1024,
                    depth_3, True, 1, 0.5, 'C2f.8')
        SPPF_9 = SPPF(self.network, state_dict, C2f_8.get_output(0),
                      width_1024, width_1024, 5, 'SPPF.9')

        # Top-down path: nearest-neighbour upsample to the spatial size of
        # the lateral feature map, then concatenate.
        Upsample_10 = self.network.add_resize(SPPF_9.get_output(0))
        assert Upsample_10, 'Add Upsample_10 failed'
        Upsample_10.resize_mode = trt.ResizeMode.NEAREST
        Upsample_10.shape = Upsample_10.get_output(
            0).shape[:2] + C2f_6.get_output(0).shape[2:]
        input_tensors11 = [Upsample_10.get_output(0), C2f_6.get_output(0)]
        Cat_11 = self.network.add_concatenation(input_tensors11)
        C2f_12 = C2f(self.network, state_dict, Cat_11.get_output(0), width_512,
                     depth_3, False, 1, 0.5, 'C2f.12')
        Upsample13 = self.network.add_resize(C2f_12.get_output(0))
        assert Upsample13, 'Add Upsample13 failed'
        Upsample13.resize_mode = trt.ResizeMode.NEAREST
        Upsample13.shape = Upsample13.get_output(
            0).shape[:2] + C2f_4.get_output(0).shape[2:]
        input_tensors14 = [Upsample13.get_output(0), C2f_4.get_output(0)]
        Cat_14 = self.network.add_concatenation(input_tensors14)
        C2f_15 = C2f(self.network, state_dict, Cat_14.get_output(0), width_256,
                     depth_3, False, 1, 0.5, 'C2f.15')

        # Bottom-up path: strided conv, concatenate with the earlier map.
        Conv_16 = Conv(self.network, state_dict, C2f_15.get_output(0),
                       width_256, 3, 2, 1, 'Conv.16')
        input_tensors17 = [Conv_16.get_output(0), C2f_12.get_output(0)]
        Cat_17 = self.network.add_concatenation(input_tensors17)
        C2f_18 = C2f(self.network, state_dict, Cat_17.get_output(0), width_512,
                     depth_3, False, 1, 0.5, 'C2f.18')
        Conv_19 = Conv(self.network, state_dict, C2f_18.get_output(0),
                       width_512, 3, 2, 1, 'Conv.19')
        input_tensors20 = [Conv_19.get_output(0), SPPF_9.get_output(0)]
        Cat_20 = self.network.add_concatenation(input_tensors20)
        C2f_21 = C2f(self.network, state_dict, Cat_20.get_output(0),
                     width_1024, depth_3, False, 1, 0.5, 'C2f.21')

        # Detection head over the three output scales; every output of the
        # returned layer becomes a network output.
        input_tensors22 = [
            C2f_15.get_output(0),
            C2f_18.get_output(0),
            C2f_21.get_output(0)
        ]
        batched_nms = Detect(self.network, state_dict, input_tensors22,
                             strides, 'Detect.22', reg_max, fp16, iou_thres,
                             conf_thres, topk)
        for o in range(batched_nms.num_outputs):
            self.network.mark_output(batched_nms.get_output(o))


class TRTModule(torch.nn.Module):
    """Run a serialized TensorRT engine as a ``torch.nn.Module``.

    Inputs and outputs are exchanged as CUDA tensors through raw device
    pointers; execution happens on a dedicated CUDA stream.
    """
    # TensorRT binding dtype -> torch dtype used to allocate outputs.
    dtypeMapping = {
        trt.bool: torch.bool,
        trt.int8: torch.int8,
        trt.int32: torch.int32,
        trt.float16: torch.float16,
        trt.float32: torch.float32
    }

    def __init__(self, weight: Union[str, Path],
                 device: Optional[torch.device]) -> None:
        """
        Args:
            weight: path to a serialized ``.engine`` file.
            device: device on which outputs are allocated and the stream is
                created; ``None`` means ``cuda:0``.
        """
        super(TRTModule, self).__init__()
        self.weight = Path(weight) if isinstance(weight, str) else weight
        self.device = device if device is not None else torch.device('cuda:0')
        # Use the resolved device so the stream and the outputs always live
        # on the same GPU (the raw argument may be None).
        self.stream = torch.cuda.Stream(device=self.device)
        self.__init_engine()
        self.__init_bindings()

    def __init_engine(self) -> None:
        """Deserialize the engine and record binding names and counts."""
        logger = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(logger, namespace='')
        with trt.Runtime(logger) as runtime:
            model = runtime.deserialize_cuda_engine(self.weight.read_bytes())

        context = model.create_execution_context()

        names = [model.get_binding_name(i) for i in range(model.num_bindings)]
        self.num_bindings = model.num_bindings
        # One device pointer per binding; refreshed on every forward call.
        self.bindings: List[int] = [0] * self.num_bindings
        num_inputs, num_outputs = 0, 0

        for i in range(model.num_bindings):
            if model.binding_is_input(i):
                num_inputs += 1
            else:
                num_outputs += 1

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.model = model
        self.context = context
        # Assumes all input bindings precede all outputs; __init_bindings
        # asserts this by checking names by index.
        self.input_names = names[:num_inputs]
        self.output_names = names[num_inputs:]

    def __init_bindings(self) -> None:
        """Collect binding metadata and preallocate outputs for static shapes."""
        dynamic = False
        Tensor = namedtuple('Tensor', ('name', 'dtype', 'shape'))
        inp_info = []
        out_info = []
        for i, name in enumerate(self.input_names):
            assert self.model.get_binding_name(i) == name
            dtype = self.dtypeMapping[self.model.get_binding_dtype(i)]
            shape = tuple(self.model.get_binding_shape(i))
            # A -1 dimension marks a dynamic input shape.
            if -1 in shape:
                dynamic = True
            inp_info.append(Tensor(name, dtype, shape))
        for i, name in enumerate(self.output_names):
            i += self.num_inputs
            assert self.model.get_binding_name(i) == name
            dtype = self.dtypeMapping[self.model.get_binding_dtype(i)]
            shape = tuple(self.model.get_binding_shape(i))
            out_info.append(Tensor(name, dtype, shape))

        if not dynamic:
            # Static shapes: allocate outputs once and reuse them every call.
            self.output_tensor = [
                torch.empty(info.shape, dtype=info.dtype, device=self.device)
                for info in out_info
            ]
        self.is_dynamic = dynamic
        self.inp_info = inp_info
        # forward() reads self.out_info on the dynamic-shape path.
        self.out_info = out_info
        # Backward-compatible alias for the previously misspelled attribute.
        self.out_infp = out_info

    def set_profiler(self, profiler: Optional[trt.IProfiler]):
        """Attach ``profiler`` to the execution context (``trt.Profiler()`` if None)."""
        self.context.profiler = profiler \
            if profiler is not None else trt.Profiler()

    def forward(self, *inputs) -> Union[Tuple, torch.Tensor]:
        """Execute the engine on ``inputs`` and block until it finishes.

        Args:
            *inputs: exactly ``self.num_inputs`` tensors, in binding order;
                made contiguous before their pointers are bound.

        Returns:
            A tuple of output tensors, or the single tensor when the engine
            has one output. For static-shape engines the same preallocated
            tensors are returned on every call, so a later call overwrites
            earlier results.
        """
        assert len(inputs) == self.num_inputs
        contiguous_inputs: List[torch.Tensor] = [
            i.contiguous() for i in inputs
        ]

        for i in range(self.num_inputs):
            self.bindings[i] = contiguous_inputs[i].data_ptr()
            if self.is_dynamic:
                self.context.set_binding_shape(
                    i, tuple(contiguous_inputs[i].shape))

        outputs: List[torch.Tensor] = []

        for i in range(self.num_outputs):
            j = i + self.num_inputs
            if self.is_dynamic:
                shape = tuple(self.context.get_binding_shape(j))
                output = torch.empty(size=shape,
                                     dtype=self.out_info[i].dtype,
                                     device=self.device)
            else:
                output = self.output_tensor[i]
            self.bindings[j] = output.data_ptr()
            outputs.append(output)

        # self.stream is a separately created CUDA stream: make it wait for
        # work already queued on the current stream (kernels that produced
        # the inputs, the .contiguous() copies, output allocation) before
        # TensorRT touches those buffers.
        self.stream.wait_stream(torch.cuda.current_stream(self.device))
        self.context.execute_async_v2(self.bindings, self.stream.cuda_stream)
        self.stream.synchronize()

        return tuple(outputs) if len(outputs) > 1 else outputs[0]


class TRTProfilerV1(trt.IProfiler):
    """Accumulate per-layer times and print a summary sorted by cost."""

    def __init__(self):
        trt.IProfiler.__init__(self)
        self.total_runtime = 0.0
        # layer name -> accumulated time in microseconds
        self.recorder = defaultdict(float)

    def report_layer_time(self, layer_name: str, ms: float):
        # Times arrive in milliseconds; store microseconds.
        self.total_runtime += ms * 1000
        self.recorder[layer_name] += ms * 1000

    def report(self):
        """Print accumulated per-layer costs (descending) and the total."""
        f = '\t%40s\t\t\t\t%10.4f'
        print('\t%40s\t\t\t\t%10s' % ('layername', 'cost(us)'))
        for name, cost in sorted(self.recorder.items(), key=lambda x: -x[1]):
            # Long names are truncated to 35 chars and marked with '****'.
            print(
                f % (name if len(name) < 40 else name[:35] + ' ' + '*' * 4,
                     cost))
        print(f'\nTotal Inference Time: {self.total_runtime:.4f}(us)')


class TRTProfilerV0(trt.IProfiler):
    """Print each layer's time (in ms) as soon as it is reported."""

    def __init__(self):
        trt.IProfiler.__init__(self)

    def report_layer_time(self, layer_name: str, ms: float):
        f = '\t%40s\t\t\t\t%10.4fms'
        # Long names are truncated to 35 chars and marked with '****'.
        print(f % (layer_name if len(layer_name) < 40 else layer_name[:35] +
                   ' ' + '*' * 4, ms))