You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
30 lines
767 B
30 lines
767 B
2 years ago
|
from models import TRTModule, TRTProfilerV0 # isort:skip
|
||
|
import argparse
|
||
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
def profile(args):
    """Run a single forward pass of a TensorRT engine with a profiler attached.

    Builds a ``TRTModule`` from the engine file path in ``args.engine`` on the
    device named by ``args.device``. It then installs a fresh
    ``TRTProfilerV0`` via ``set_profiler`` and feeds one random normal tensor
    shaped like the engine's first input (``inp_info[0].shape``).

    The forward output is discarded; the call exists for the profiler's side
    effects. Presumably the profiler reports per-layer timings. TODO: confirm
    in ``models.TRTProfilerV0``.

    NOTE(review): ``torch.randn`` uses PyTorch's default float dtype and only
    the first engine input is fed. Confirm this matches the engines being
    profiled.

    Args:
        args: Parsed CLI namespace exposing ``engine`` and ``device``.
    """
    target = torch.device(args.device)
    engine = TRTModule(args.engine, target)
    engine.set_profiler(TRTProfilerV0())

    first_input_shape = engine.inp_info[0].shape
    dummy = torch.randn(first_input_shape, device=target)
    engine(dummy)
|
||
|
|
||
|
|
||
|
def parse_args(argv=None):
    """Parse command-line options for the TensorRT engine profiler.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so existing
            no-argument callers behave as before.

    Returns:
        argparse.Namespace with:
            engine (str): path to the engine file (required).
            device (str): torch device string, default ``'cuda:0'``.

    Raises:
        SystemExit: raised by argparse when ``--engine`` is missing or an
            option is malformed.
    """
    parser = argparse.ArgumentParser()
    # Required: profile() hands this value straight to TRTModule. Without it,
    # argparse would default to None and the failure would only surface later,
    # inside engine loading, instead of as a clear usage error here.
    parser.add_argument('--engine', type=str, required=True, help='Engine file')
    parser.add_argument('--device',
                        type=str,
                        default='cuda:0',
                        help='TensorRT infer device')
    return parser.parse_args(argv)
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Script entry point: read CLI options, then run the one-shot profile.
    profile(parse_args())
|