|
|
@ -8,9 +8,9 @@ from pathlib import Path |
|
|
|
from types import SimpleNamespace |
|
|
|
from types import SimpleNamespace |
|
|
|
from typing import Dict, List, Union |
|
|
|
from typing import Dict, List, Union |
|
|
|
|
|
|
|
|
|
|
|
from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, SETTINGS, SETTINGS_YAML, |
|
|
|
from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, |
|
|
|
IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, yaml_load, |
|
|
|
SETTINGS_YAML, IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, |
|
|
|
yaml_print) |
|
|
|
yaml_load, yaml_print) |
|
|
|
|
|
|
|
|
|
|
|
# Define valid tasks and modes |
|
|
|
# Define valid tasks and modes |
|
|
|
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark' |
|
|
|
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark' |
|
|
@ -146,8 +146,23 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove |
|
|
|
return IterableSimpleNamespace(**cfg) |
|
|
|
return IterableSimpleNamespace(**cfg) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_save_dir(args): |
|
|
|
|
|
|
|
"""Return save_dir as created from train/val/predict arguments.""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if getattr(args, 'save_dir', None): |
|
|
|
|
|
|
|
save_dir = args.save_dir |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
from ultralytics.utils.files import increment_path |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
project = args.project or Path(SETTINGS['runs_dir']) / args.task |
|
|
|
|
|
|
|
name = args.name or f'{args.mode}' |
|
|
|
|
|
|
|
save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return Path(save_dir) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _handle_deprecation(custom): |
|
|
|
def _handle_deprecation(custom): |
|
|
|
"""Hardcoded function to handle deprecated config keys""" |
|
|
|
"""Hardcoded function to handle deprecated config keys.""" |
|
|
|
|
|
|
|
|
|
|
|
for key in custom.copy().keys(): |
|
|
|
for key in custom.copy().keys(): |
|
|
|
if key == 'hide_labels': |
|
|
|
if key == 'hide_labels': |
|
|
@ -171,6 +186,7 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None): |
|
|
|
Args: |
|
|
|
Args: |
|
|
|
custom (dict): a dictionary of custom configuration options |
|
|
|
custom (dict): a dictionary of custom configuration options |
|
|
|
base (dict): a dictionary of base configuration options |
|
|
|
base (dict): a dictionary of base configuration options |
|
|
|
|
|
|
|
e (Error, optional): An optional error that is passed by the calling function. |
|
|
|
""" |
|
|
|
""" |
|
|
|
custom = _handle_deprecation(custom) |
|
|
|
custom = _handle_deprecation(custom) |
|
|
|
base_keys, custom_keys = (set(x.keys()) for x in (base, custom)) |
|
|
|
base_keys, custom_keys = (set(x.keys()) for x in (base, custom)) |
|
|
|