|
|
@ -178,6 +178,16 @@ class Exporter: |
|
|
|
if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}: # 'coreml' aliases |
|
|
|
if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}: # 'coreml' aliases |
|
|
|
fmt = "coreml" |
|
|
|
fmt = "coreml" |
|
|
|
fmts = tuple(export_formats()["Argument"][1:]) # available export formats |
|
|
|
fmts = tuple(export_formats()["Argument"][1:]) # available export formats |
|
|
|
|
|
|
|
if fmt not in fmts: |
|
|
|
|
|
|
|
import difflib |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Get the closest match if format is invalid |
|
|
|
|
|
|
|
matches = difflib.get_close_matches(fmt, fmts, n=1, cutoff=0.6) # 60% similarity required to match |
|
|
|
|
|
|
|
if closest_match: |
|
|
|
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ Invalid export format='{fmt}', updating to format='{matches[0]}'") |
|
|
|
|
|
|
|
fmt = closest_match[0] |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}") |
|
|
|
flags = [x == fmt for x in fmts] |
|
|
|
flags = [x == fmt for x in fmts] |
|
|
|
if sum(flags) != 1: |
|
|
|
if sum(flags) != 1: |
|
|
|
raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}") |
|
|
|
raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}") |
|
|
|