`ultralytics 8.1.46` add TensorRT 10 support (#9516)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: 九是否随意的称呼 <1069679911@qq.com>
pull/9949/head^2 v8.1.46
Burhan 10 months ago committed by GitHub
parent ea03db9984
commit 4ffd6ee6d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 7
      tests/test_cuda.py
  2. 2
      ultralytics/__init__.py
  3. 29
      ultralytics/engine/exporter.py
  4. 71
      ultralytics/nn/autobackend.py

@ -19,6 +19,13 @@ def test_checks():
assert torch.cuda.is_available() == CUDA_IS_AVAILABLE
assert torch.cuda.device_count() == CUDA_DEVICE_COUNT
@pytest.mark.slow
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available")
def test_export_engine():
"""Test exporting the YOLO model to NVIDIA TensorRT format."""
f = YOLO(MODEL).export(format="engine", device=0)
YOLO(f)(BUS, device=0)
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available")
def test_train():

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.45"
__version__ = "8.1.46"
from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

@ -658,6 +658,7 @@ class Exporter:
def export_engine(self, prefix=colorstr("TensorRT:")):
"""YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt."""
assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
self.args.simplify = True
f_onnx, _ = self.export_onnx() # run before trt import https://github.com/ultralytics/ultralytics/issues/7016
try:
@ -666,12 +667,10 @@ class Exporter:
if LINUX:
check_requirements("nvidia-tensorrt", cmds="-U --index-url https://pypi.ngc.nvidia.com")
import tensorrt as trt # noqa
check_version(trt.__version__, "7.0.0", hard=True) # require tensorrt>=7.0.0
self.args.simplify = True
LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
is_trt10 = int(trt.__version__.split(".")[0]) >= 10 # is TensorRT >= 10
assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
f = self.file.with_suffix(".engine") # TensorRT engine file
logger = trt.Logger(trt.Logger.INFO)
@ -680,7 +679,11 @@ class Exporter:
builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = int(self.args.workspace * (1 << 30))
workspace = int(self.args.workspace * (1 << 30))
if is_trt10:
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)
else: # TensorRT versions 7, 8
config.max_workspace_size = workspace
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)
@ -699,27 +702,31 @@ class Exporter:
if shape[0] <= 1:
LOGGER.warning(f"{prefix} WARNING ⚠ 'dynamic=True' model requires max batch size, i.e. 'batch=16'")
profile = builder.create_optimization_profile()
min_shape = (1, shape[1], 32, 32) # minimum input shape
opt_shape = (max(1, shape[0] // 2), *shape[1:]) # optimal input shape
max_shape = (*shape[:2], *(max(1, self.args.workspace) * d for d in shape[2:])) # max input shape
for inp in inputs:
profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
profile.set_shape(inp.name, min_shape, opt_shape, max_shape)
config.add_optimization_profile(profile)
LOGGER.info(
f"{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}"
)
if builder.platform_has_fast_fp16 and self.args.half:
half = builder.platform_has_fast_fp16 and self.args.half
LOGGER.info(f"{prefix} building FP{16 if half else 32} engine as {f}")
if half:
config.set_flag(trt.BuilderFlag.FP16)
# Free CUDA memory
del self.model
torch.cuda.empty_cache()
# Write file
with builder.build_engine(network, config) as engine, open(f, "wb") as t:
build = builder.build_serialized_network if is_trt10 else builder.build_engine
with build(network, config) as engine, open(f, "wb") as t:
# Metadata
meta = json.dumps(self.metadata)
t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
t.write(meta.encode())
# Model
t.write(engine.serialize())
t.write(engine if is_trt10 else engine.serialize())
return f, None

@ -234,23 +234,47 @@ class AutoBackend(nn.Module):
meta_len = int.from_bytes(f.read(4), byteorder="little") # read metadata length
metadata = json.loads(f.read(meta_len).decode("utf-8")) # read metadata
model = runtime.deserialize_cuda_engine(f.read()) # read engine
context = model.create_execution_context()
# Model context
try:
context = model.create_execution_context()
except Exception as e: # model is None
LOGGER.error(f"ERROR: TensorRT model exported with a different version than {trt.__version__}\n")
raise e
bindings = OrderedDict()
output_names = []
fp16 = False # default updated below
dynamic = False
for i in range(model.num_bindings):
name = model.get_binding_name(i)
dtype = trt.nptype(model.get_binding_dtype(i))
if model.binding_is_input(i):
if -1 in tuple(model.get_binding_shape(i)): # dynamic
dynamic = True
context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
if dtype == np.float16:
fp16 = True
else: # output
output_names.append(name)
shape = tuple(context.get_binding_shape(i))
is_trt10 = not hasattr(model, "num_bindings")
num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings)
for i in num:
if is_trt10:
name = model.get_tensor_name(i)
dtype = trt.nptype(model.get_tensor_dtype(name))
is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT
if is_input:
if -1 in tuple(model.get_tensor_shape(name)):
dynamic = True
context.set_input_shape(name, tuple(model.get_tensor_profile_shape(name, 0)[1]))
if dtype == np.float16:
fp16 = True
else:
output_names.append(name)
shape = tuple(context.get_tensor_shape(name))
else: # TensorRT < 10.0
name = model.get_binding_name(i)
dtype = trt.nptype(model.get_binding_dtype(i))
is_input = model.binding_is_input(i)
if model.binding_is_input(i):
if -1 in tuple(model.get_binding_shape(i)): # dynamic
dynamic = True
context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[1]))
if dtype == np.float16:
fp16 = True
else:
output_names.append(name)
shape = tuple(context.get_binding_shape(i))
im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
@ -463,13 +487,20 @@ class AutoBackend(nn.Module):
# TensorRT
elif self.engine:
if self.dynamic and im.shape != self.bindings["images"].shape:
i = self.model.get_binding_index("images")
self.context.set_binding_shape(i, im.shape) # reshape if dynamic
self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
for name in self.output_names:
i = self.model.get_binding_index(name)
self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
if self.dynamic or im.shape != self.bindings["images"].shape:
if self.is_trt10:
self.context.set_input_shape("images", im.shape)
self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
for name in self.output_names:
self.bindings[name].data.resize_(tuple(self.context.get_tensor_shape(name)))
else:
i = self.model.get_binding_index("images")
self.context.set_binding_shape(i, im.shape)
self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
for name in self.output_names:
i = self.model.get_binding_index(name)
self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
s = self.bindings["images"].shape
assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
self.binding_addrs["images"] = int(im.data_ptr())

Loading…
Cancel
Save