From 23b4f697c97a3f8d1dc2eae1712628f4a583e66d Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 28 Aug 2023 10:43:41 +0200 Subject: [PATCH] Add new `get_save_dir()` function (#4602) --- ultralytics/cfg/__init__.py | 24 ++++++++++++++++++++---- ultralytics/engine/model.py | 4 ++-- ultralytics/engine/predictor.py | 11 +++-------- ultralytics/engine/results.py | 6 +----- ultralytics/engine/trainer.py | 16 +++++----------- ultralytics/engine/validator.py | 13 ++++--------- 6 files changed, 35 insertions(+), 39 deletions(-) diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py index 19558ec700..9073b05eea 100644 --- a/ultralytics/cfg/__init__.py +++ b/ultralytics/cfg/__init__.py @@ -8,9 +8,9 @@ from pathlib import Path from types import SimpleNamespace from typing import Dict, List, Union -from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, SETTINGS, SETTINGS_YAML, - IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, yaml_load, - yaml_print) +from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, + SETTINGS_YAML, IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, + yaml_load, yaml_print) # Define valid tasks and modes MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark' @@ -146,8 +146,23 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove return IterableSimpleNamespace(**cfg) +def get_save_dir(args): + """Return save_dir as created from train/val/predict arguments.""" + + if getattr(args, 'save_dir', None): + save_dir = args.save_dir + else: + from ultralytics.utils.files import increment_path + + project = args.project or Path(SETTINGS['runs_dir']) / args.task + name = args.name or f'{args.mode}' + save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True) + + return Path(save_dir) + + def _handle_deprecation(custom): - """Hardcoded function to handle deprecated config keys""" + """Hardcoded function to handle deprecated config keys.""" for key in custom.copy().keys(): if key == 'hide_labels': @@ -171,6 +186,7 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None): Args: custom (dict): a dictionary of custom configuration options base (dict): a dictionary of base configuration options + e (Error, optional): An optional error that is passed by the calling function. """ custom = _handle_deprecation(custom) base_keys, custom_keys = (set(x.keys()) for x in (base, custom)) diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index a2424d4b69..56bdb62fde 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -5,7 +5,7 @@ import sys from pathlib import Path from typing import Union -from ultralytics.cfg import get_cfg +from ultralytics.cfg import get_cfg, get_save_dir from ultralytics.engine.exporter import Exporter from ultralytics.hub.utils import HUB_WEB_ROOT from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load @@ -239,7 +239,7 @@ class Model: else: # only update args if predictor is already setup self.predictor.args = get_cfg(self.predictor.args, overrides) if 'project' in overrides or 'name' in overrides: - self.predictor.save_dir = self.predictor.get_save_dir() + self.predictor.save_dir = get_save_dir(self.predictor.args) # Set prompts for SAM/FastSAM if len and hasattr(self.predictor, 'set_prompts'): self.predictor.set_prompts(prompts) diff --git a/ultralytics/engine/predictor.py b/ultralytics/engine/predictor.py index f3c6808dc7..010e2b970c 100644 --- a/ultralytics/engine/predictor.py +++ b/ultralytics/engine/predictor.py @@ -34,11 +34,11 @@ import cv2 import numpy as np import torch -from ultralytics.cfg import get_cfg +from ultralytics.cfg import get_cfg, get_save_dir from ultralytics.data import load_inference_source from ultralytics.data.augment import LetterBox, classify_transforms from ultralytics.nn.autobackend import AutoBackend -from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, SETTINGS, WINDOWS, callbacks, colorstr, ops +from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops from ultralytics.utils.checks import check_imgsz, check_imshow from ultralytics.utils.files import increment_path from ultralytics.utils.torch_utils import select_device, smart_inference_mode @@ -84,7 +84,7 @@ class BasePredictor: overrides (dict, optional): Configuration overrides. Defaults to None. """ self.args = get_cfg(cfg, overrides) - self.save_dir = self.get_save_dir() + self.save_dir = get_save_dir(self.args) if self.args.conf is None: self.args.conf = 0.25 # default conf=0.25 self.done_warmup = False @@ -108,11 +108,6 @@ class BasePredictor: self.txt_path = None callbacks.add_integration_callbacks(self) - def get_save_dir(self): - project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task - name = self.args.name or f'{self.args.mode}' - return increment_path(Path(project) / name, exist_ok=self.args.exist_ok) - def preprocess(self, im): """Prepares input image before inference. diff --git a/ultralytics/engine/results.py b/ultralytics/engine/results.py index d636395f03..eee46909fd 100644 --- a/ultralytics/engine/results.py +++ b/ultralytics/engine/results.py @@ -323,14 +323,10 @@ class Results(SimpleClass): if self.probs is not None: LOGGER.warning('WARNING ⚠️ Classify task do not support `save_crop`.') return - if isinstance(save_dir, str): - save_dir = Path(save_dir) - if isinstance(file_name, str): - file_name = Path(file_name) for d in self.boxes: save_one_box(d.xyxy, self.orig_img.copy(), - file=save_dir / self.names[int(d.cls)] / f'{file_name.stem}.jpg', + file=Path(save_dir) / self.names[int(d.cls)] / f'{Path(file_name).stem}.jpg', BGR=True) def tojson(self, normalize=False): diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 9b280e3896..84167df16b 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -23,15 +23,15 @@ from torch.cuda import amp from torch.nn.parallel import DistributedDataParallel as DDP from tqdm import tqdm -from ultralytics.cfg import get_cfg +from ultralytics.cfg import get_cfg, get_save_dir from ultralytics.data.utils import check_cls_dataset, check_det_dataset from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights -from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks, clean_url, - colorstr, emojis, yaml_save) +from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM_BAR_FORMAT, __version__, callbacks, clean_url, colorstr, + emojis, yaml_save) from ultralytics.utils.autobatch import check_train_batch_size from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command -from ultralytics.utils.files import get_latest_run, increment_path +from ultralytics.utils.files import get_latest_run from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device, strip_optimizer) @@ -91,13 +91,7 @@ class BaseTrainer: init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic) # Dirs - project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task - name = self.args.name or f'{self.args.mode}' - if hasattr(self.args, 'save_dir'): - self.save_dir = Path(self.args.save_dir) - else: - self.save_dir = Path( - increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in (-1, 0) else True)) + self.save_dir = get_save_dir(self.args) self.wdir = self.save_dir / 'weights' # weights dir if RANK in (-1, 0): self.wdir.mkdir(parents=True, exist_ok=True) # make dir diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py index bf2496117a..e445368ca5 100644 --- a/ultralytics/engine/validator.py +++ b/ultralytics/engine/validator.py @@ -26,12 +26,11 @@ import numpy as np import torch from tqdm import tqdm -from ultralytics.cfg import get_cfg +from ultralytics.cfg import get_cfg, get_save_dir from ultralytics.data.utils import check_cls_dataset, check_det_dataset from ultralytics.nn.autobackend import AutoBackend -from ultralytics.utils import LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis +from ultralytics.utils import LOGGER, TQDM_BAR_FORMAT, callbacks, colorstr, emojis from ultralytics.utils.checks import check_imgsz -from ultralytics.utils.files import increment_path from ultralytics.utils.ops import Profile from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode @@ -71,7 +70,7 @@ class BaseValidator: Args: dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation. - save_dir (Path): Directory to save results. + save_dir (Path, optional): Directory to save results. pbar (tqdm.tqdm): Progress bar for displaying progress. args (SimpleNamespace): Configuration for the validator. _callbacks (dict): Dictionary to store various callback functions. @@ -93,12 +92,8 @@ class BaseValidator: self.jdict = None self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} - project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task - name = self.args.name or f'{self.args.mode}' - self.save_dir = save_dir or increment_path(Path(project) / name, - exist_ok=self.args.exist_ok if RANK in (-1, 0) else True) + self.save_dir = save_dir or get_save_dir(self.args) (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) - if self.args.conf is None: self.args.conf = 0.001 # default conf=0.001