Add new `get_save_dir()` function (#4602)

pull/4600/head
Glenn Jocher 1 year ago committed by GitHub
parent 1121ef2409
commit 23b4f697c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 24
      ultralytics/cfg/__init__.py
  2. 4
      ultralytics/engine/model.py
  3. 11
      ultralytics/engine/predictor.py
  4. 6
      ultralytics/engine/results.py
  5. 16
      ultralytics/engine/trainer.py
  6. 13
      ultralytics/engine/validator.py

@ -8,9 +8,9 @@ from pathlib import Path
from types import SimpleNamespace
from typing import Dict, List, Union
from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, SETTINGS, SETTINGS_YAML,
IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, yaml_load,
yaml_print)
from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS,
SETTINGS_YAML, IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn,
yaml_load, yaml_print)
# Define valid tasks and modes
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
@ -146,8 +146,23 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
return IterableSimpleNamespace(**cfg)
def get_save_dir(args):
"""Return save_dir as created from train/val/predict arguments."""
if getattr(args, 'save_dir', None):
save_dir = args.save_dir
else:
from ultralytics.utils.files import increment_path
project = args.project or Path(SETTINGS['runs_dir']) / args.task
name = args.name or f'{args.mode}'
save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True)
return Path(save_dir)
def _handle_deprecation(custom):
"""Hardcoded function to handle deprecated config keys"""
"""Hardcoded function to handle deprecated config keys."""
for key in custom.copy().keys():
if key == 'hide_labels':
@ -171,6 +186,7 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
Args:
custom (dict): a dictionary of custom configuration options
base (dict): a dictionary of base configuration options
e (Error, optional): An optional error that is passed by the calling function.
"""
custom = _handle_deprecation(custom)
base_keys, custom_keys = (set(x.keys()) for x in (base, custom))

@ -5,7 +5,7 @@ import sys
from pathlib import Path
from typing import Union
from ultralytics.cfg import get_cfg
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.engine.exporter import Exporter
from ultralytics.hub.utils import HUB_WEB_ROOT
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
@ -239,7 +239,7 @@ class Model:
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
if 'project' in overrides or 'name' in overrides:
self.predictor.save_dir = self.predictor.get_save_dir()
self.predictor.save_dir = get_save_dir(self.predictor.args)
# Set prompts for SAM/FastSAM
if len and hasattr(self.predictor, 'set_prompts'):
self.predictor.set_prompts(prompts)

@ -34,11 +34,11 @@ import cv2
import numpy as np
import torch
from ultralytics.cfg import get_cfg
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data import load_inference_source
from ultralytics.data.augment import LetterBox, classify_transforms
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, SETTINGS, WINDOWS, callbacks, colorstr, ops
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
from ultralytics.utils.checks import check_imgsz, check_imshow
from ultralytics.utils.files import increment_path
from ultralytics.utils.torch_utils import select_device, smart_inference_mode
@ -84,7 +84,7 @@ class BasePredictor:
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
self.args = get_cfg(cfg, overrides)
self.save_dir = self.get_save_dir()
self.save_dir = get_save_dir(self.args)
if self.args.conf is None:
self.args.conf = 0.25 # default conf=0.25
self.done_warmup = False
@ -108,11 +108,6 @@ class BasePredictor:
self.txt_path = None
callbacks.add_integration_callbacks(self)
def get_save_dir(self):
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f'{self.args.mode}'
return increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
def preprocess(self, im):
"""Prepares input image before inference.

@ -323,14 +323,10 @@ class Results(SimpleClass):
if self.probs is not None:
LOGGER.warning('WARNING ⚠ Classify task do not support `save_crop`.')
return
if isinstance(save_dir, str):
save_dir = Path(save_dir)
if isinstance(file_name, str):
file_name = Path(file_name)
for d in self.boxes:
save_one_box(d.xyxy,
self.orig_img.copy(),
file=save_dir / self.names[int(d.cls)] / f'{file_name.stem}.jpg',
file=Path(save_dir) / self.names[int(d.cls)] / f'{Path(file_name).stem}.jpg',
BGR=True)
def tojson(self, normalize=False):

@ -23,15 +23,15 @@ from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
from ultralytics.cfg import get_cfg
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks, clean_url,
colorstr, emojis, yaml_save)
from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM_BAR_FORMAT, __version__, callbacks, clean_url, colorstr,
emojis, yaml_save)
from ultralytics.utils.autobatch import check_train_batch_size
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.utils.files import get_latest_run, increment_path
from ultralytics.utils.files import get_latest_run
from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
strip_optimizer)
@ -91,13 +91,7 @@ class BaseTrainer:
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
# Dirs
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f'{self.args.mode}'
if hasattr(self.args, 'save_dir'):
self.save_dir = Path(self.args.save_dir)
else:
self.save_dir = Path(
increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in (-1, 0) else True))
self.save_dir = get_save_dir(self.args)
self.wdir = self.save_dir / 'weights' # weights dir
if RANK in (-1, 0):
self.wdir.mkdir(parents=True, exist_ok=True) # make dir

@ -26,12 +26,11 @@ import numpy as np
import torch
from tqdm import tqdm
from ultralytics.cfg import get_cfg
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
from ultralytics.utils import LOGGER, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.files import increment_path
from ultralytics.utils.ops import Profile
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
@ -71,7 +70,7 @@ class BaseValidator:
Args:
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
save_dir (Path): Directory to save results.
save_dir (Path, optional): Directory to save results.
pbar (tqdm.tqdm): Progress bar for displaying progress.
args (SimpleNamespace): Configuration for the validator.
_callbacks (dict): Dictionary to store various callback functions.
@ -93,12 +92,8 @@ class BaseValidator:
self.jdict = None
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f'{self.args.mode}'
self.save_dir = save_dir or increment_path(Path(project) / name,
exist_ok=self.args.exist_ok if RANK in (-1, 0) else True)
self.save_dir = save_dir or get_save_dir(self.args)
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
if self.args.conf is None:
self.args.conf = 0.001 # default conf=0.001

Loading…
Cancel
Save