You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
55 lines
1.8 KiB
55 lines
1.8 KiB
import argparse |
|
import os |
|
|
|
from models import EngineBuilder |
|
|
|
# Request lazy CUDA module loading for this process. This unconditionally
# overwrites any CUDA_MODULE_LOADING value already present in the environment.
# NOTE(review): presumably this must take effect before the CUDA context is
# created — confirm nothing imported above initializes CUDA eagerly.
os.environ.update(CUDA_MODULE_LOADING='LAZY')
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options for the engine build.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests or programmatic calls).
            The default ``None`` keeps argparse's normal behavior.

    Returns:
        argparse.Namespace with attributes ``weights``, ``iou_thres``,
        ``conf_thres``, ``topk``, ``input_shape`` (list of 4 ints),
        ``fp16`` and ``device``.

    Raises:
        SystemExit: (status 2, via ``parser.error``) when the arguments are
            invalid, including when ``--input-shape`` does not receive
            exactly four integers.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights',
                        type=str,
                        required=True,
                        help='Weights file')
    parser.add_argument('--iou-thres',
                        type=float,
                        default=0.65,
                        help='IOU threshold for NMS plugin')
    parser.add_argument('--conf-thres',
                        type=float,
                        default=0.25,
                        help='CONF threshold for NMS plugin')
    parser.add_argument('--topk',
                        type=int,
                        default=100,
                        help='Max number of detection bboxes')
    parser.add_argument('--input-shape',
                        nargs='+',
                        type=int,
                        default=[1, 3, 640, 640],
                        help='Model input shape only for api builder')
    parser.add_argument('--fp16',
                        action='store_true',
                        help='Build model with fp16 mode')
    parser.add_argument('--device',
                        type=str,
                        default='cuda:0',
                        help='TensorRT builder device')
    args = parser.parse_args(argv)
    # Validate explicitly rather than with `assert`: asserts are stripped
    # under `python -O`, and parser.error prints a proper usage message.
    if len(args.input_shape) != 4:
        parser.error('--input-shape expects exactly 4 integers, got '
                     f'{len(args.input_shape)}: {args.input_shape}')
    return args
|
|
|
|
|
def main(args):
    """Build an engine from parsed command-line options.

    Constructs an ``EngineBuilder`` for ``args.weights`` on ``args.device``
    and calls its ``build`` method with the precision, input-shape and NMS
    settings taken from *args*. Returns ``None``.
    """
    engine_builder = EngineBuilder(args.weights, args.device)
    build_options = dict(
        fp16=args.fp16,
        input_shape=args.input_shape,
        iou_thres=args.iou_thres,
        conf_thres=args.conf_thres,
        topk=args.topk,
    )
    engine_builder.build(**build_options)
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the CLI options and run the build.
    main(parse_args())
|
|
|