Exported model batch size validation fix (#14845)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/8620/merge
Francesco Mattioli 4 months ago committed by GitHub
parent 8648572809
commit db4c43bafb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 4
      ultralytics/engine/validator.py

@ -136,8 +136,8 @@ class BaseValidator:
if engine:
self.args.batch = model.batch_size
elif not pt and not jit:
self.args.batch = 1 # export.py models default to batch-size 1
LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")
self.args.batch = model.metadata.get("batch", 1) # export.py models default to batch-size 1
LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")
if str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
self.data = check_det_dataset(self.args.data)

Loading…
Cancel
Save