|
|
|
@ -101,23 +101,47 @@ from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_d |
|
|
|
|
def export_formats(): |
|
|
|
|
"""Ultralytics YOLO export formats.""" |
|
|
|
|
x = [ |
|
|
|
|
["PyTorch", "-", ".pt", True, True], |
|
|
|
|
["TorchScript", "torchscript", ".torchscript", True, True], |
|
|
|
|
["ONNX", "onnx", ".onnx", True, True], |
|
|
|
|
["OpenVINO", "openvino", "_openvino_model", True, False], |
|
|
|
|
["TensorRT", "engine", ".engine", False, True], |
|
|
|
|
["CoreML", "coreml", ".mlpackage", True, False], |
|
|
|
|
["TensorFlow SavedModel", "saved_model", "_saved_model", True, True], |
|
|
|
|
["TensorFlow GraphDef", "pb", ".pb", True, True], |
|
|
|
|
["TensorFlow Lite", "tflite", ".tflite", True, False], |
|
|
|
|
["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False], |
|
|
|
|
["TensorFlow.js", "tfjs", "_web_model", True, False], |
|
|
|
|
["PaddlePaddle", "paddle", "_paddle_model", True, True], |
|
|
|
|
["MNN", "mnn", ".mnn", True, True], |
|
|
|
|
["NCNN", "ncnn", "_ncnn_model", True, True], |
|
|
|
|
["IMX", "imx", "_imx_model", True, True], |
|
|
|
|
["PyTorch", "-", ".pt", True, True, []], |
|
|
|
|
["TorchScript", "torchscript", ".torchscript", True, True, ["optimize", "batch"]], |
|
|
|
|
["ONNX", "onnx", ".onnx", True, True, ["half", "dynamic", "simplify", "opset", "batch"]], |
|
|
|
|
["OpenVINO", "openvino", "_openvino_model", True, False, ["half", "int8", "batch"]], |
|
|
|
|
["TensorRT", "engine", ".engine", False, True, ["half", "dynamic", "simplify", "int8", "batch"]], |
|
|
|
|
["CoreML", "coreml", ".mlpackage", True, False, ["half", "int8", "nms", "batch"]], |
|
|
|
|
["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["keras", "int8", "batch"]], |
|
|
|
|
["TensorFlow GraphDef", "pb", ".pb", True, True, ["batch"]], |
|
|
|
|
["TensorFlow Lite", "tflite", ".tflite", True, False, ["half", "int8", "batch"]], |
|
|
|
|
["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False, []], |
|
|
|
|
["TensorFlow.js", "tfjs", "_web_model", True, False, ["half", "int8", "batch"]], |
|
|
|
|
["PaddlePaddle", "paddle", "_paddle_model", True, True, ["batch"]], |
|
|
|
|
["MNN", "mnn", ".mnn", True, True, ["batch", "int8", "half"]], |
|
|
|
|
["NCNN", "ncnn", "_ncnn_model", True, True, ["half", "batch"]], |
|
|
|
|
["IMX", "imx", "_imx_model", True, True, ["int8"]], |
|
|
|
|
] |
|
|
|
|
return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU"], zip(*x))) |
|
|
|
|
return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU", "Arguments"], zip(*x))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_args(format, passed_args, valid_args): |
|
|
|
|
""" |
|
|
|
|
Validates arguments based on format. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
format (str): The export format. |
|
|
|
|
passed_args (Namespace): The arguments used during export. |
|
|
|
|
valid_args (dict): List of valid arguments for the format. |
|
|
|
|
|
|
|
|
|
Raises: |
|
|
|
|
AssertionError: If an argument that's not supported by the export format is used, or if format doesn't have the supported arguments listed. |
|
|
|
|
""" |
|
|
|
|
# Only check valid usage of these args |
|
|
|
|
export_args = ["half", "int8", "dynamic", "keras", "nms", "batch"] |
|
|
|
|
|
|
|
|
|
assert valid_args is not None, f"ERROR ❌️ valid arguments for '{format}' not listed." |
|
|
|
|
custom = {"batch": 1, "data": None, "device": None} # exporter defaults |
|
|
|
|
default_args = get_cfg(DEFAULT_CFG, custom) |
|
|
|
|
for arg in export_args: |
|
|
|
|
not_default = getattr(passed_args, arg, None) != getattr(default_args, arg, None) |
|
|
|
|
if not_default: |
|
|
|
|
assert arg in valid_args, f"ERROR ❌️ argument '{arg}' is not supported for format='{format}'" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gd_outputs(gd): |
|
|
|
@ -182,7 +206,8 @@ class Exporter: |
|
|
|
|
fmt = "engine" |
|
|
|
|
if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}: # 'coreml' aliases |
|
|
|
|
fmt = "coreml" |
|
|
|
|
fmts = tuple(export_formats()["Argument"][1:]) # available export formats |
|
|
|
|
fmts_dict = export_formats() |
|
|
|
|
fmts = tuple(fmts_dict["Argument"][1:]) # available export formats |
|
|
|
|
if fmt not in fmts: |
|
|
|
|
import difflib |
|
|
|
|
|
|
|
|
@ -224,7 +249,8 @@ class Exporter: |
|
|
|
|
assert dla in {"0", "1"}, f"Expected self.args.device='dla:0' or 'dla:1, but got {self.args.device}." |
|
|
|
|
self.device = select_device("cpu" if self.args.device is None else self.args.device) |
|
|
|
|
|
|
|
|
|
# Checks |
|
|
|
|
# Argument compatibility checks |
|
|
|
|
validate_args(fmt, self.args, fmts_dict["Arguments"][flags.index(True) + 1]) |
|
|
|
|
if imx and not self.args.int8: |
|
|
|
|
LOGGER.warning("WARNING ⚠️ IMX only supports int8 export, setting int8=True.") |
|
|
|
|
self.args.int8 = True |
|
|
|
|