UltralyticsAssistant 2 months ago
parent eb6a8d0c31
commit 29d6493c0a
  1. 9
      ultralytics/engine/exporter.py
  2. 4
      ultralytics/nn/autobackend.py

@ -1055,9 +1055,10 @@ class Exporter:
@try_export
def export_mct(self, prefix=colorstr("Sony MCT:")):
check_requirements(["model_compression_toolkit==2.1.0", "sony-custom-layers[torch]"])
import subprocess
import model_compression_toolkit as mct
import onnx
import subprocess
from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device, set_working_device
from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms
@ -1113,7 +1114,7 @@ class Exporter:
config = mct.core.CoreConfig(
mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=10),
quantization_config=mct.core.QuantizationConfig(concat_threshold_update=True)
quantization_config=mct.core.QuantizationConfig(concat_threshold_update=True),
)
resource_utilization = mct.core.ResourceUtilization(weights_memory=3146176 * 0.76)
@ -1172,7 +1173,9 @@ class Exporter:
try:
subprocess.run(["java", "--version"], check=True)
except FileNotFoundError:
LOGGER.error("Java 17 is required for the imx500 conversion. \n Please install Java with: \n sudo apt install openjdk-17-jdk openjdk-17-jre")
LOGGER.error(
"Java 17 is required for the imx500 conversion. \n Please install Java with: \n sudo apt install openjdk-17-jdk openjdk-17-jre"
)
return None
subprocess.run(["imxconv-pt", "-i", "yolov8n_mct_model.onnx", "-o", "yolov8n_imx500_model"], check=True)

@ -178,7 +178,9 @@ class AutoBackend(nn.Module):
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda else ["CPUExecutionProvider"]
if mct:
check_requirements(["model_compression_toolkit==2.1.0", "sony-custom-layers[torch]", "onnxruntime-extensions"])
check_requirements(
["model_compression_toolkit==2.1.0", "sony-custom-layers[torch]", "onnxruntime-extensions"]
)
LOGGER.info(f"Loading {w} for ONNX MCT quantization inference...")
import mct_quantizers as mctq
from sony_custom_layers.pytorch.object_detection import nms_ort # noqa

Loading…
Cancel
Save