You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
29 lines
767 B
29 lines
767 B
from models import TRTModule, TRTProfilerV0 # isort:skip |
|
import argparse |
|
|
|
import torch |
|
|
|
|
|
def profile(args):
    """Run one profiled forward pass of a TensorRT engine on random input.

    Args:
        args: Namespace-like object providing ``device`` (passed to
            ``torch.device``) and ``engine`` (passed to ``TRTModule``).
    """
    target_device = torch.device(args.device)

    # Load the engine first, then hook the profiler onto it before running.
    trt_engine = TRTModule(args.engine, target_device)
    trt_engine.set_profiler(TRTProfilerV0())

    # The dummy input takes its shape from the engine's first input entry.
    input_shape = trt_engine.inp_info[0].shape
    dummy_input = torch.randn(input_shape, device=target_device)

    # The output is discarded; only the single invocation matters here.
    trt_engine(dummy_input)
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the profiling script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what
            the function did before this parameter existed.

    Returns:
        argparse.Namespace with two string attributes: ``engine`` and
        ``device`` (defaults to ``'cuda:0'``).

    Raises:
        SystemExit: Raised by argparse when ``--engine`` is missing or an
            unknown option is given.
    """
    parser = argparse.ArgumentParser()
    # The engine path has no sensible default. Requiring it makes argparse
    # print a usage error instead of handing ``None`` on to the engine
    # loader in profile().
    parser.add_argument('--engine',
                        type=str,
                        required=True,
                        help='Engine file')
    parser.add_argument('--device',
                        type=str,
                        default='cuda:0',
                        help='TensorRT infer device')
    return parser.parse_args(argv)
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the CLI options and run one profiling pass.
    profile(parse_args())
|
|
|