Fix arbitrary imgsz for TFLite (#17138)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
pull/16962/head^2
Francesco Mattioli 4 weeks ago committed by GitHub
parent 6c12c1d69f
commit b0c18b7190
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 6
      ultralytics/engine/exporter.py

@ -890,8 +890,10 @@ class Exporter:
tmp_file = f / "tmp_tflite_int8_calibration_images.npy" # int8 calibration images file
if self.args.data:
f.mkdir()
images = [batch["img"].permute(0, 2, 3, 1) for batch in self.get_int8_calibration_dataloader(prefix)]
images = torch.cat(images, 0).float()
images = [batch["img"] for batch in self.get_int8_calibration_dataloader(prefix)]
images = torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz).permute(
0, 2, 3, 1
)
np.save(str(tmp_file), images.numpy().astype(np.float32)) # BHWC
np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]

Loading…
Cancel
Save