|
|
|
@ -1,6 +1,6 @@ |
|
|
|
|
from pathlib import Path |
|
|
|
|
from typing import Optional, Union, List |
|
|
|
|
from collections import namedtuple |
|
|
|
|
from typing import Optional, Union, List, Tuple |
|
|
|
|
from collections import namedtuple, defaultdict |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
import tensorrt as trt |
|
|
|
@ -40,9 +40,9 @@ class EngineBuilder: |
|
|
|
|
outputs = [network.get_output(i) for i in range(network.num_outputs)] |
|
|
|
|
|
|
|
|
|
for inp in inputs: |
|
|
|
|
logger.log(trt.Logger.WARNING, f'input "{inp.name}" with shape{inp.shape} {inp.dtype}') |
|
|
|
|
logger.log(trt.Logger.WARNING, f'input "{inp.name}" with shape: {inp.shape} dtype: {inp.dtype}') |
|
|
|
|
for out in outputs: |
|
|
|
|
logger.log(trt.Logger.WARNING, f'output "{out.name}" with shape{out.shape} {out.dtype}') |
|
|
|
|
logger.log(trt.Logger.WARNING, f'output "{out.name}" with shape: {out.shape} dtype: {out.dtype}') |
|
|
|
|
if fp16 and builder.platform_has_fast_fp16: |
|
|
|
|
config.set_flag(trt.BuilderFlag.FP16) |
|
|
|
|
self.weight = self.checkpoint.with_suffix('.engine') |
|
|
|
@ -53,7 +53,7 @@ class EngineBuilder: |
|
|
|
|
self.weight.write_bytes(engine.serialize()) |
|
|
|
|
logger.log(trt.Logger.WARNING, f'Build tensorrt engine finish.\nSave in {str(self.weight.absolute())}') |
|
|
|
|
|
|
|
|
|
def build(self, fp16: bool = True, with_profiling=True): |
|
|
|
|
def build(self, fp16: bool = True, with_profiling=True) -> None: |
|
|
|
|
self.__build_engine(fp16, with_profiling) |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -64,7 +64,7 @@ class TRTModule(torch.nn.Module): |
|
|
|
|
trt.float16: torch.float16, |
|
|
|
|
trt.float32: torch.float32} |
|
|
|
|
|
|
|
|
|
def __init__(self, weight: Union[str, Path], device: Optional[torch.device]): |
|
|
|
|
def __init__(self, weight: Union[str, Path], device: Optional[torch.device]) -> None: |
|
|
|
|
super(TRTModule, self).__init__() |
|
|
|
|
self.weight = Path(weight) if isinstance(weight, str) else weight |
|
|
|
|
self.device = device if device is not None else torch.device('cuda:0') |
|
|
|
@ -72,7 +72,7 @@ class TRTModule(torch.nn.Module): |
|
|
|
|
self.__init_engine() |
|
|
|
|
self.__init_bindings() |
|
|
|
|
|
|
|
|
|
def __init_engine(self): |
|
|
|
|
def __init_engine(self) -> None: |
|
|
|
|
logger = trt.Logger(trt.Logger.WARNING) |
|
|
|
|
trt.init_libnvinfer_plugins(logger, namespace='') |
|
|
|
|
with trt.Runtime(logger) as runtime: |
|
|
|
@ -98,7 +98,7 @@ class TRTModule(torch.nn.Module): |
|
|
|
|
self.input_names = names[:num_inputs] |
|
|
|
|
self.output_names = names[num_inputs:] |
|
|
|
|
|
|
|
|
|
def __init_bindings(self): |
|
|
|
|
def __init_bindings(self) -> None: |
|
|
|
|
dynamic = False |
|
|
|
|
Tensor = namedtuple('Tensor', ('name', 'dtype', 'shape')) |
|
|
|
|
inp_info = [] |
|
|
|
@ -122,7 +122,10 @@ class TRTModule(torch.nn.Module): |
|
|
|
|
self.inp_info = inp_info |
|
|
|
|
self.out_infp = out_info |
|
|
|
|
|
|
|
|
|
def forward(self, *inputs): |
|
|
|
|
def set_profiler(self, profiler: Optional[trt.IProfiler]): |
|
|
|
|
self.context.profiler = profiler if profiler is not None else trt.Profiler() |
|
|
|
|
|
|
|
|
|
def forward(self, *inputs) -> Union[Tuple, torch.Tensor]: |
|
|
|
|
|
|
|
|
|
assert len(inputs) == self.num_inputs |
|
|
|
|
contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs] |
|
|
|
@ -148,3 +151,30 @@ class TRTModule(torch.nn.Module): |
|
|
|
|
self.stream.synchronize() |
|
|
|
|
|
|
|
|
|
return tuple(outputs) if len(outputs) > 1 else outputs[0] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TRTProfilerV1(trt.IProfiler): |
|
|
|
|
def __init__(self): |
|
|
|
|
trt.IProfiler.__init__(self) |
|
|
|
|
self.total_runtime = 0.0 |
|
|
|
|
self.recorder = defaultdict(float) |
|
|
|
|
|
|
|
|
|
def report_layer_time(self, layer_name: str, ms: float): |
|
|
|
|
self.total_runtime += ms * 1000 |
|
|
|
|
self.recorder[layer_name] += ms * 1000 |
|
|
|
|
|
|
|
|
|
def report(self): |
|
|
|
|
f = '\t%40s\t\t\t\t%10.4f' |
|
|
|
|
print('\t%40s\t\t\t\t%10s' % ('layername', 'cost(us)')) |
|
|
|
|
for name, cost in sorted(self.recorder.items(), key=lambda x: -x[1]): |
|
|
|
|
print(f % (name if len(name) < 40 else name[:35] + ' ' + '*' * 4, cost)) |
|
|
|
|
print(f'\nTotal Inference Time: {self.total_runtime:.4f}(us)') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TRTProfilerV0(trt.IProfiler): |
|
|
|
|
def __init__(self): |
|
|
|
|
trt.IProfiler.__init__(self) |
|
|
|
|
|
|
|
|
|
def report_layer_time(self, layer_name: str, ms: float): |
|
|
|
|
f = '\t%40s\t\t\t\t%10.4fms' |
|
|
|
|
print(f % (layer_name if len(layer_name) < 40 else layer_name[:35] + ' ' + '*' * 4, ms)) |
|
|
|
|