|
|
|
@ -30,8 +30,8 @@ from ultralytics.utils import ( |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# Define valid tasks and modes |
|
|
|
|
MODES = "train", "val", "predict", "export", "track", "benchmark" |
|
|
|
|
TASKS = "detect", "segment", "classify", "pose", "obb" |
|
|
|
|
MODES = {"train", "val", "predict", "export", "track", "benchmark"} |
|
|
|
|
TASKS = {"detect", "segment", "classify", "pose", "obb"} |
|
|
|
|
TASK2DATA = { |
|
|
|
|
"detect": "coco8.yaml", |
|
|
|
|
"segment": "coco8-seg.yaml", |
|
|
|
@ -93,8 +93,8 @@ CLI_HELP_MSG = f""" |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
# Define keys for arg type checks |
|
|
|
|
CFG_FLOAT_KEYS = "warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time" |
|
|
|
|
CFG_FRACTION_KEYS = ( |
|
|
|
|
CFG_FLOAT_KEYS = {"warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"} |
|
|
|
|
CFG_FRACTION_KEYS = { |
|
|
|
|
"dropout", |
|
|
|
|
"iou", |
|
|
|
|
"lr0", |
|
|
|
@ -118,8 +118,8 @@ CFG_FRACTION_KEYS = ( |
|
|
|
|
"conf", |
|
|
|
|
"iou", |
|
|
|
|
"fraction", |
|
|
|
|
) # fraction floats 0.0 - 1.0 |
|
|
|
|
CFG_INT_KEYS = ( |
|
|
|
|
} # fraction floats 0.0 - 1.0 |
|
|
|
|
CFG_INT_KEYS = { |
|
|
|
|
"epochs", |
|
|
|
|
"patience", |
|
|
|
|
"batch", |
|
|
|
@ -133,8 +133,8 @@ CFG_INT_KEYS = ( |
|
|
|
|
"workspace", |
|
|
|
|
"nbs", |
|
|
|
|
"save_period", |
|
|
|
|
) |
|
|
|
|
CFG_BOOL_KEYS = ( |
|
|
|
|
} |
|
|
|
|
CFG_BOOL_KEYS = { |
|
|
|
|
"save", |
|
|
|
|
"exist_ok", |
|
|
|
|
"verbose", |
|
|
|
@ -169,7 +169,7 @@ CFG_BOOL_KEYS = ( |
|
|
|
|
"nms", |
|
|
|
|
"profile", |
|
|
|
|
"multi_scale", |
|
|
|
|
) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def cfg2dict(cfg): |
|
|
|
@ -219,33 +219,46 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove |
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.") |
|
|
|
|
|
|
|
|
|
# Type and Value checks |
|
|
|
|
check_cfg(cfg) |
|
|
|
|
|
|
|
|
|
# Return instance |
|
|
|
|
return IterableSimpleNamespace(**cfg) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_cfg(cfg, hard=True): |
|
|
|
|
"""Check Ultralytics configuration argument types and values.""" |
|
|
|
|
for k, v in cfg.items(): |
|
|
|
|
if v is not None: # None values may be from optional args |
|
|
|
|
if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)): |
|
|
|
|
raise TypeError( |
|
|
|
|
f"'{k}={v}' is of invalid type {type(v).__name__}. " |
|
|
|
|
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')" |
|
|
|
|
) |
|
|
|
|
elif k in CFG_FRACTION_KEYS: |
|
|
|
|
if not isinstance(v, (int, float)): |
|
|
|
|
if hard: |
|
|
|
|
raise TypeError( |
|
|
|
|
f"'{k}={v}' is of invalid type {type(v).__name__}. " |
|
|
|
|
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')" |
|
|
|
|
) |
|
|
|
|
cfg[k] = float(v) |
|
|
|
|
elif k in CFG_FRACTION_KEYS: |
|
|
|
|
if not isinstance(v, (int, float)): |
|
|
|
|
if hard: |
|
|
|
|
raise TypeError( |
|
|
|
|
f"'{k}={v}' is of invalid type {type(v).__name__}. " |
|
|
|
|
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')" |
|
|
|
|
) |
|
|
|
|
cfg[k] = float(v) |
|
|
|
|
if not (0.0 <= v <= 1.0): |
|
|
|
|
raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.") |
|
|
|
|
elif k in CFG_INT_KEYS and not isinstance(v, int): |
|
|
|
|
raise TypeError( |
|
|
|
|
f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')" |
|
|
|
|
) |
|
|
|
|
if hard: |
|
|
|
|
raise TypeError( |
|
|
|
|
f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')" |
|
|
|
|
) |
|
|
|
|
cfg[k] = int(v) |
|
|
|
|
elif k in CFG_BOOL_KEYS and not isinstance(v, bool): |
|
|
|
|
raise TypeError( |
|
|
|
|
f"'{k}={v}' is of invalid type {type(v).__name__}. " |
|
|
|
|
f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# Return instance |
|
|
|
|
return IterableSimpleNamespace(**cfg) |
|
|
|
|
if hard: |
|
|
|
|
raise TypeError( |
|
|
|
|
f"'{k}={v}' is of invalid type {type(v).__name__}. " |
|
|
|
|
f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')" |
|
|
|
|
) |
|
|
|
|
cfg[k] = bool(v) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_save_dir(args, name=None): |
|
|
|
@ -464,10 +477,10 @@ def entrypoint(debug=""): |
|
|
|
|
overrides = {} # basic overrides, i.e. imgsz=320 |
|
|
|
|
for a in merge_equals_args(args): # merge spaces around '=' sign |
|
|
|
|
if a.startswith("--"): |
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.") |
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require leading dashes '--', updating to '{a[2:]}'.") |
|
|
|
|
a = a[2:] |
|
|
|
|
if a.endswith(","): |
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.") |
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.") |
|
|
|
|
a = a[:-1] |
|
|
|
|
if "=" in a: |
|
|
|
|
try: |
|
|
|
@ -504,7 +517,7 @@ def entrypoint(debug=""): |
|
|
|
|
mode = overrides.get("mode") |
|
|
|
|
if mode is None: |
|
|
|
|
mode = DEFAULT_CFG.mode or "predict" |
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.") |
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ 'mode' argument is missing. Valid modes are {MODES}. Using default 'mode={mode}'.") |
|
|
|
|
elif mode not in MODES: |
|
|
|
|
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}") |
|
|
|
|
|
|
|
|
@ -520,7 +533,7 @@ def entrypoint(debug=""): |
|
|
|
|
model = overrides.pop("model", DEFAULT_CFG.model) |
|
|
|
|
if model is None: |
|
|
|
|
model = "yolov8n.pt" |
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.") |
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ 'model' argument is missing. Using default 'model={model}'.") |
|
|
|
|
overrides["model"] = model |
|
|
|
|
stem = Path(model).stem.lower() |
|
|
|
|
if "rtdetr" in stem: # guess architecture |
|
|
|
@ -554,15 +567,15 @@ def entrypoint(debug=""): |
|
|
|
|
# Mode |
|
|
|
|
if mode in ("predict", "track") and "source" not in overrides: |
|
|
|
|
overrides["source"] = DEFAULT_CFG.source or ASSETS |
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.") |
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.") |
|
|
|
|
elif mode in ("train", "val"): |
|
|
|
|
if "data" not in overrides and "resume" not in overrides: |
|
|
|
|
overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data) |
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.") |
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.") |
|
|
|
|
elif mode == "export": |
|
|
|
|
if "format" not in overrides: |
|
|
|
|
overrides["format"] = DEFAULT_CFG.format or "torchscript" |
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.") |
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ 'format' argument is missing. Using default 'format={overrides['format']}'.") |
|
|
|
|
|
|
|
|
|
# Run command in python |
|
|
|
|
getattr(model, mode)(**overrides) # default args from model |
|
|
|
|