diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py index 8a2765c98..4a40a8829 100644 --- a/ultralytics/engine/validator.py +++ b/ultralytics/engine/validator.py @@ -136,8 +136,8 @@ class BaseValidator: if engine: self.args.batch = model.batch_size elif not pt and not jit: - self.args.batch = 1 # export.py models default to batch-size 1 - LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models") + self.args.batch = model.metadata.get("batch", 1) # export.py models default to batch-size 1 + LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})") if str(self.args.data).split(".")[-1] in {"yaml", "yml"}: self.data = check_det_dataset(self.args.data)