Fix error on TensorRT export with float `workspace` value (#17352)

pull/17355/head^2
Mohammed Yasin 1 week ago committed by GitHub
parent 603fa84774
commit d0abd95f95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      ultralytics/engine/exporter.py

@ -791,7 +791,7 @@ class Exporter:
LOGGER.warning(f"{prefix} WARNING ⚠ 'dynamic=True' model requires max batch size, i.e. 'batch=16'")
profile = builder.create_optimization_profile()
min_shape = (1, shape[1], 32, 32) # minimum input shape
max_shape = (*shape[:2], *(max(1, self.args.workspace) * d for d in shape[2:])) # max input shape
max_shape = (*shape[:2], *(int(max(1, self.args.workspace) * d) for d in shape[2:])) # max input shape
for inp in inputs:
profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
config.add_optimization_profile(profile)

Loading…
Cancel
Save