pull/17463/head^2
UltralyticsAssistant 4 months ago
parent 16ac869266
commit b221a26601
  1. 24
      ultralytics/engine/exporter.py

@ -247,7 +247,9 @@ class Exporter:
assert not getattr(model, "end2end", False), "TFLite INT8 export not supported for end2end models."
if edgetpu or imx500:
if not LINUX:
raise SystemError("Edge TPU (https://coral.ai/docs/edgetpu/compiler) and IMX500 (https://developer.aitrios.sony-semicon.com/en/raspberrypi-ai-camera/documentation/imx500-converter) export only supported on Linux.")
raise SystemError(
"Edge TPU (https://coral.ai/docs/edgetpu/compiler) and IMX500 (https://developer.aitrios.sony-semicon.com/en/raspberrypi-ai-camera/documentation/imx500-converter) export only supported on Linux."
)
elif self.args.batch != 1: # see github.com/ultralytics/ultralytics/pull/13420
LOGGER.warning("WARNING ⚠ Edge TPU export requires batch size 1, setting batch=1.")
self.args.batch = 1
@ -1113,16 +1115,19 @@ class Exporter:
raise ValueError("IMX500 export is only supported for YOLOv8 detection models")
check_requirements(("model-compression-toolkit==2.1.1", "sony-custom-layers==0.2.0", "tensorflow==2.12.0"))
check_requirements("imx500-converter[pt]==3.14.3")
import subprocess
import model_compression_toolkit as mct
import onnx
import subprocess
from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms
try:
subprocess.run(["java", "--version"], check=True)
except EnvironmentError:
raise EnvironmentError("Java 17 is required for the imx500 conversion. \n Please install Java with: \n sudo apt install openjdk-17-jdk openjdk-17-jre")
except OSError:
raise OSError(
"Java 17 is required for the imx500 conversion. \n Please install Java with: \n sudo apt install openjdk-17-jdk openjdk-17-jre"
)
def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)):
for batch in dataloader:
@ -1214,10 +1219,13 @@ class Exporter:
meta.key, meta.value = k, str(v)
onnx.save(model_onnx, f)
output_dir = Path(str(self.file).replace(self.file.suffix, "_imx500_model"))
subprocess.run(["imxconv-pt", "-i", str(f), "-o", str(output_dir), "--no-input-persistency", "--overwrite-output"], check=True)
subprocess.run(
["imxconv-pt", "-i", str(f), "-o", str(output_dir), "--no-input-persistency", "--overwrite-output"],
check=True,
)
with open(output_dir / "labels.txt", "w") as file:
file.writelines([f"{name}\n" for _, name in self.model.names.items()])

Loading…
Cancel
Save