`ultralytics 8.0.12` - Hydra removal (#506)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Pronoy Mandal <lukex9442@gmail.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
pull/530/head v8.0.12
Glenn Jocher 2 years ago committed by GitHub
parent 6eec39162a
commit c5fccc3fc4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      .github/workflows/ci.yaml
  2. 1
      .gitignore
  3. 2
      docker/Dockerfile
  4. 14
      docs/predict.md
  5. 1
      requirements.txt
  6. 3
      setup.py
  7. 4
      tests/test_engine.py
  8. 2
      tests/test_python.py
  9. 2
      ultralytics/__init__.py
  10. 4
      ultralytics/hub/utils.py
  11. 12
      ultralytics/nn/tasks.py
  12. 156
      ultralytics/yolo/cli.py
  13. 231
      ultralytics/yolo/configs/__init__.py
  14. 108
      ultralytics/yolo/configs/default.yaml
  15. 77
      ultralytics/yolo/configs/hydra_patch.py
  16. 4
      ultralytics/yolo/data/build.py
  17. 51
      ultralytics/yolo/data/dataloaders/v5loader.py
  18. 34
      ultralytics/yolo/data/dataset.py
  19. 2
      ultralytics/yolo/data/datasets/coco.yaml
  20. 14
      ultralytics/yolo/engine/exporter.py
  21. 6
      ultralytics/yolo/engine/model.py
  22. 12
      ultralytics/yolo/engine/predictor.py
  23. 28
      ultralytics/yolo/engine/trainer.py
  24. 10
      ultralytics/yolo/engine/validator.py
  25. 10
      ultralytics/yolo/utils/__init__.py
  26. 2
      ultralytics/yolo/utils/dist.py
  27. 7
      ultralytics/yolo/utils/torch_utils.py
  28. 1
      ultralytics/yolo/v8/__init__.py
  29. 6
      ultralytics/yolo/v8/classify/predict.py
  30. 10
      ultralytics/yolo/v8/classify/train.py
  31. 7
      ultralytics/yolo/v8/classify/val.py
  32. 6
      ultralytics/yolo/v8/detect/predict.py
  33. 10
      ultralytics/yolo/v8/detect/train.py
  34. 8
      ultralytics/yolo/v8/detect/val.py
  35. 6
      ultralytics/yolo/v8/segment/predict.py
  36. 10
      ultralytics/yolo/v8/segment/train.py
  37. 6
      ultralytics/yolo/v8/segment/val.py

@ -85,6 +85,7 @@ jobs:
shell: bash # for Windows compatibility
run: |
yolo task=detect mode=train data=coco8.yaml model=yolov8n.yaml epochs=1 imgsz=32
yolo task=detect mode=train data=coco8.yaml model=yolov8n.pt epochs=1 imgsz=32
yolo task=detect mode=val data=coco8.yaml model=runs/detect/train/weights/last.pt imgsz=32
yolo task=detect mode=predict model=runs/detect/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
yolo mode=export model=runs/detect/train/weights/last.pt imgsz=32 format=torchscript
@ -92,6 +93,7 @@ jobs:
shell: bash # for Windows compatibility
run: |
yolo task=segment mode=train data=coco8-seg.yaml model=yolov8n-seg.yaml epochs=1 imgsz=32
yolo task=segment mode=train data=coco8-seg.yaml model=yolov8n-seg.pt epochs=1 imgsz=32
yolo task=segment mode=val data=coco8-seg.yaml model=runs/segment/train/weights/last.pt imgsz=32
yolo task=segment mode=predict model=runs/segment/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
yolo mode=export model=runs/segment/train/weights/last.pt imgsz=32 format=torchscript
@ -99,6 +101,7 @@ jobs:
shell: bash # for Windows compatibility
run: |
yolo task=classify mode=train data=mnist160 model=yolov8n-cls.yaml epochs=1 imgsz=32
yolo task=classify mode=train data=mnist160 model=yolov8n-cls.pt epochs=1 imgsz=32
yolo task=classify mode=val data=mnist160 model=runs/classify/train/weights/last.pt imgsz=32
yolo task=classify mode=predict model=runs/classify/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
yolo mode=export model=runs/classify/train/weights/last.pt imgsz=32 format=torchscript

1
.gitignore vendored

@ -136,6 +136,7 @@ wandb/
.DS_Store
# Neural Network weights -----------------------------------------------------------------------------------------------
weights/
*.weights
*.pt
*.pb

@ -10,7 +10,7 @@ ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Aria
# Remove torch nightly and install torch stable
RUN rm -rf /opt/pytorch # remove 1.2GB dir
RUN pip uninstall -y torchtext torch torchvision
RUN pip uninstall -y torchtext pillow torch torchvision
RUN pip install --no-cache torch torchvision
# Install linux packages

@ -6,9 +6,9 @@ Inference or prediction of a task returns a list of `Results` objects. Alternati
inputs = [img, img] # list of np arrays
results = model(inputs) # List of Results objects
for result in results:
boxes = results.boxes # Boxes object for bbox outputs
masks = results.masks # Masks object for segmenation masks outputs
probs = results.probs # Class probabilities for classification outputs
boxes = result.boxes # Boxes object for bbox outputs
masks = result.masks # Masks object for segmenation masks outputs
probs = result.probs # Class probabilities for classification outputs
...
```
=== "Getting a Generator"
@ -16,9 +16,9 @@ Inference or prediction of a task returns a list of `Results` objects. Alternati
inputs = [img, img] # list of np arrays
results = model(inputs, stream=True) # Generator of Results objects
for result in results:
boxes = results.boxes # Boxes object for bbox outputs
masks = results.masks # Masks object for segmenation masks outputs
probs = results.probs # Class probabilities for classification outputs
boxes = result.boxes # Boxes object for bbox outputs
masks = result.masks # Masks object for segmenation masks outputs
probs = result.probs # Class probabilities for classification outputs
...
```
@ -69,4 +69,4 @@ results.masks.segments # bounding coordinates of masks, List[segment] * N
results.probs # cls prob, (num_class, )
```
Class reference documentation for `Results` module and its components can be found [here](reference/results.md)
Class reference documentation for `Results` module and its components can be found [here](reference/results.md)

@ -2,7 +2,6 @@
# Usage: pip install -r requirements.txt
# Base ----------------------------------------
hydra-core>=1.2.0
matplotlib>=3.2.2
numpy>=1.18.5
opencv-python>=4.1.1

@ -51,4 +51,5 @@ setup(
"Operating System :: MacOS", "Operating System :: Microsoft :: Windows"],
keywords="machine-learning, deep-learning, vision, ML, DL, AI, YOLO, YOLOv3, YOLOv5, YOLOv8, HUB, Ultralytics",
entry_points={
'console_scripts': ['yolo = ultralytics.yolo.cli:entrypoint', 'ultralytics = ultralytics.yolo.cli:entrypoint']})
'console_scripts':
['yolo = ultralytics.yolo.configs:entrypoint', 'ultralytics = ultralytics.yolo.configs:entrypoint']})

@ -3,13 +3,13 @@
from pathlib import Path
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, SETTINGS
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, ROOT, SETTINGS
from ultralytics.yolo.v8 import classify, detect, segment
CFG_DET = 'yolov8n.yaml'
CFG_SEG = 'yolov8n-seg.yaml'
CFG_CLS = 'squeezenet1_0'
CFG = get_config(DEFAULT_CONFIG)
CFG = get_config(DEFAULT_CFG_PATH)
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
SOURCE = ROOT / "assets"

@ -49,6 +49,8 @@ def test_predict_img():
assert len(output) == 1, "predict test failed"
output = model(source=[img, img], save=True, save_txt=True) # batch
assert len(output) == 2, "predict test failed"
output = model(source=[img, img], save=True, stream=True) # stream
assert len(list(output)) == 2, "predict test failed"
tens = torch.zeros(320, 640, 3)
output = model(tens.numpy())
assert len(output) == 1, "predict test failed"

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = "8.0.11"
__version__ = "8.0.12"
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import ops

@ -7,7 +7,7 @@ import time
import requests
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, LOGGER, RANK, SETTINGS, TryExcept, colorstr, emojis
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, TryExcept, colorstr, emojis
PREFIX = colorstr('Ultralytics: ')
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
@ -143,7 +143,7 @@ def sync_analytics(cfg, all_keys=False, enabled=False):
if SETTINGS['sync'] and RANK in {-1, 0} and enabled:
cfg = dict(cfg) # convert type from DictConfig to dict
if not all_keys:
cfg = {k: v for k, v in cfg.items() if v != DEFAULT_CONFIG_DICT.get(k, None)} # retain non-default values
cfg = {k: v for k, v in cfg.items() if v != DEFAULT_CFG_DICT.get(k, None)} # retain non-default values
cfg['uuid'] = SETTINGS['uuid'] # add the device UUID to the configuration data
# Send a request to the HUB API to sync analytics

@ -10,7 +10,7 @@ import torch.nn as nn
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
GhostBottleneck, GhostConv, Segment)
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER, colorstr, yaml_load
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, yaml_load
from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
model_info, scale_img, time_sync)
@ -113,7 +113,7 @@ class BaseModel(nn.Module):
thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.
Returns:
bool: True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
(bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
"""
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
@ -321,11 +321,11 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
# Model compatibility updates
ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # attach args to model
ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
ckpt.pt_path = weights # attach *.pt file path to model
if not hasattr(ckpt, 'stride'):
ckpt.stride = torch.tensor([32.])
@ -359,11 +359,11 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
from ultralytics.yolo.utils.downloads import attempt_download
ckpt = torch.load(attempt_download(weight), map_location='cpu') # load
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
# Model compatibility updates
model.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # attach args to model
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
model.pt_path = weight # attach *.pt file path to model
if not hasattr(model, 'stride'):
model.stride = torch.tensor([32.])

@ -1,156 +0,0 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import argparse
import re
import shutil
import sys
from pathlib import Path
from ultralytics import __version__, yolo
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, PREFIX, checks, print_settings, yaml_load
DIR = Path(__file__).parent
CLI_HELP_MSG = \
"""
YOLOv8 CLI Usage examples:
1. Install the ultralytics package:
pip install ultralytics
2. Train, Val, Predict and Export using 'yolo' commands:
yolo TASK MODE ARGS
Where TASK (optional) is one of [detect, segment, classify]
MODE (required) is one of [train, val, predict, export]
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
For a full list of available ARGS see https://docs.ultralytics.com/config.
Train a detection model for 10 epochs with an initial learning_rate of 0.01
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
Predict a YouTube video using a pretrained segmentation model at image size 320:
yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
Validate a pretrained detection model at batch-size 1 and image size 640:
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
3. Run special commands:
yolo help
yolo checks
yolo version
yolo settings
yolo copy-config
Docs: https://docs.ultralytics.com/cli
Community: https://community.ultralytics.com
GitHub: https://github.com/ultralytics/ultralytics
"""
def cli(cfg):
"""
Run a specified task and mode with the given configuration.
Args:
cfg (DictConfig): Configuration for the task and mode.
"""
# LOGGER.info(f"{colorstr(f'Ultralytics YOLO v{ultralytics.__version__}')}")
from ultralytics.yolo.configs import get_config
if cfg.cfg:
LOGGER.info(f"{PREFIX}Overriding default config with {cfg.cfg}")
cfg = get_config(cfg.cfg)
task, mode = cfg.task.lower(), cfg.mode.lower()
# Mapping from task to module
tasks = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify}
module = tasks.get(task)
if not module:
raise SyntaxError(f"yolo task={task} is invalid. Valid tasks are: {', '.join(tasks.keys())}\n{CLI_HELP_MSG}")
# Mapping from mode to function
modes = {"train": module.train, "val": module.val, "predict": module.predict, "export": yolo.engine.exporter.export}
func = modes.get(mode)
if not func:
raise SyntaxError(f"yolo mode={mode} is invalid. Valid modes are: {', '.join(modes.keys())}\n{CLI_HELP_MSG}")
func(cfg)
def entrypoint():
"""
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
to the package. It's a combination of argparse and hydra.
This function allows for:
- passing mandatory YOLO args as a list of strings
- specifying the task to be performed, either 'detect', 'segment' or 'classify'
- specifying the mode, either 'train', 'val', 'test', or 'predict'
- running special modes like 'checks'
- passing overrides to the package's configuration
It uses the package's default config and initializes it using the passed overrides.
Then it calls the CLI function with the composed config
"""
if len(sys.argv) == 1: # no arguments passed
LOGGER.info(CLI_HELP_MSG)
return
parser = argparse.ArgumentParser(description='YOLO parser')
parser.add_argument('args', type=str, nargs='+', help='YOLO args')
args = parser.parse_args().args
args = re.sub(r'\s*=\s*', '=', ' '.join(args)).split(' ') # remove whitespaces around = sign
tasks = 'detect', 'segment', 'classify'
modes = 'train', 'val', 'predict', 'export'
special_modes = {
'help': lambda: LOGGER.info(CLI_HELP_MSG),
'checks': checks.check_yolo,
'version': lambda: LOGGER.info(__version__),
'settings': print_settings,
'copy-config': copy_default_config}
overrides = [] # basic overrides, i.e. imgsz=320
defaults = yaml_load(DEFAULT_CONFIG)
for a in args:
if '=' in a:
overrides.append(a)
elif a in tasks:
overrides.append(f'task={a}')
elif a in modes:
overrides.append(f'mode={a}')
elif a in special_modes:
special_modes[a]()
return
elif a in defaults and defaults[a] is False:
overrides.append(f'{a}=True') # auto-True for default False args, i.e. yolo show
elif a in defaults:
raise SyntaxError(f"'{a}' is a valid YOLO argument but is missing an '=' sign to set its value, "
f"i.e. try '{a}={defaults[a]}'"
f"\n{CLI_HELP_MSG}")
else:
raise SyntaxError(
f"'{a}' is not a valid YOLO argument. For a full list of valid arguments see "
f"https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/configs/default.yaml"
f"\n{CLI_HELP_MSG}")
from hydra import compose, initialize
with initialize(version_base=None, config_path=str(DEFAULT_CONFIG.parent.relative_to(DIR)), job_name="YOLO"):
cfg = compose(config_name=DEFAULT_CONFIG.name, overrides=overrides)
cli(cfg)
# Special modes --------------------------------------------------------------------------------------------------------
def copy_default_config():
new_file = Path.cwd() / DEFAULT_CONFIG.name.replace('.yaml', '_copy.yaml')
shutil.copy2(DEFAULT_CONFIG, new_file)
LOGGER.info(f"{PREFIX}{DEFAULT_CONFIG} copied to {new_file}\n"
f"Usage for running YOLO with this new custom config:\nyolo cfg={new_file} args...")

@ -1,36 +1,221 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import argparse
import re
import shutil
import sys
from difflib import get_close_matches
from pathlib import Path
from types import SimpleNamespace
from typing import Dict, Union
from omegaconf import DictConfig, OmegaConf
from ultralytics import __version__, yolo
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, PREFIX, checks, colorstr, print_settings, yaml_load
DIR = Path(__file__).parent
CLI_HELP_MSG = \
"""
YOLOv8 CLI Usage examples:
1. Install the ultralytics package:
pip install ultralytics
2. Train, Val, Predict and Export using 'yolo' commands:
yolo TASK MODE ARGS
Where TASK (optional) is one of [detect, segment, classify]
MODE (required) is one of [train, val, predict, export]
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
For a full list of available ARGS see https://docs.ultralytics.com/config.
Train a detection model for 10 epochs with an initial learning_rate of 0.01
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
Predict a YouTube video using a pretrained segmentation model at image size 320:
yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
Validate a pretrained detection model at batch-size 1 and image size 640:
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
from ultralytics.yolo.configs.hydra_patch import check_config_mismatch
3. Run special commands:
yolo help
yolo checks
yolo version
yolo settings
yolo copy-config
def get_config(config: Union[str, Path, DictConfig], overrides: Union[str, Dict] = None):
Docs: https://docs.ultralytics.com/cli
Community: https://community.ultralytics.com
GitHub: https://github.com/ultralytics/ultralytics
"""
def cfg2dict(cfg):
"""
Convert a configuration object to a dictionary.
This function converts a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
Inputs:
cfg (str) or (Path) or (SimpleNamespace): Configuration object to be converted to a dictionary.
Returns:
cfg (dict): Configuration object in dictionary format.
"""
if isinstance(cfg, (str, Path)):
cfg = yaml_load(cfg) # load dict
elif isinstance(cfg, SimpleNamespace):
cfg = vars(cfg) # convert to dict
return cfg
def get_config(config: Union[str, Path, Dict, SimpleNamespace], overrides: Dict = None):
"""
Load and merge configuration data from a file or dictionary.
Args:
config (str) or (Path) or (DictConfig): Configuration data in the form of a file name or a DictConfig object.
overrides (str) or(Dict), optional: Overrides in the form of a file name or a dictionary. Default is None.
config (str) or (Path) or (Dict) or (SimpleNamespace): Configuration data.
overrides (str) or (Dict), optional: Overrides in the form of a file name or a dictionary. Default is None.
Returns:
OmegaConf.Namespace: Training arguments namespace.
"""
if overrides is None:
overrides = {}
if isinstance(config, (str, Path)):
config = OmegaConf.load(config)
elif isinstance(config, Dict):
config = OmegaConf.create(config)
# override
if isinstance(overrides, str):
overrides = OmegaConf.load(overrides)
elif isinstance(overrides, Dict):
overrides = OmegaConf.create(overrides)
check_config_mismatch(dict(overrides).keys(), dict(config).keys())
return OmegaConf.merge(config, overrides)
(SimpleNamespace): Training arguments namespace.
"""
config = cfg2dict(config)
# Merge overrides
if overrides:
overrides = cfg2dict(overrides)
check_config_mismatch(config, overrides)
config = {**config, **overrides} # merge config and overrides dicts (prefer overrides)
# Return instance
return SimpleNamespace(**config)
def check_config_mismatch(base: Dict, custom: Dict):
"""
This function checks for any mismatched keys between a custom configuration list and a base configuration list.
If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
Inputs:
- custom (Dict): a dictionary of custom configuration options
- base (Dict): a dictionary of base configuration options
"""
base, custom = (set(x.keys()) for x in (base, custom))
mismatched = [x for x in custom if x not in base]
for option in mismatched:
LOGGER.info(f"{colorstr(option)} is not a valid key. Similar keys: {get_close_matches(option, base, 3, 0.6)}")
if mismatched:
sys.exit()
def entrypoint(debug=True):
"""
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
to the package.
This function allows for:
- passing mandatory YOLO args as a list of strings
- specifying the task to be performed, either 'detect', 'segment' or 'classify'
- specifying the mode, either 'train', 'val', 'test', or 'predict'
- running special modes like 'checks'
- passing overrides to the package's configuration
It uses the package's default config and initializes it using the passed overrides.
Then it calls the CLI function with the composed config
"""
if debug:
args = ['train', 'predict', 'model=yolov8n.pt'] # for testing
else:
if len(sys.argv) == 1: # no arguments passed
LOGGER.info(CLI_HELP_MSG)
return
parser = argparse.ArgumentParser(description='YOLO parser')
parser.add_argument('args', type=str, nargs='+', help='YOLO args')
args = parser.parse_args().args
args = re.sub(r'\s*=\s*', '=', ' '.join(args)).split(' ') # remove whitespaces around = sign
tasks = 'detect', 'segment', 'classify'
modes = 'train', 'val', 'predict', 'export'
special_modes = {
'help': lambda: LOGGER.info(CLI_HELP_MSG),
'checks': checks.check_yolo,
'version': lambda: LOGGER.info(__version__),
'settings': print_settings,
'copy-config': copy_default_config}
overrides = {} # basic overrides, i.e. imgsz=320
defaults = yaml_load(DEFAULT_CFG_PATH)
for a in args:
if '=' in a:
if a.startswith('cfg='): # custom.yaml passed
custom_config = Path(a.split('=')[-1])
LOGGER.info(f"{PREFIX}Overriding {DEFAULT_CFG_PATH} with {custom_config}")
overrides = {k: v for k, v in yaml_load(custom_config).items() if k not in {'cfg'}}
else:
k, v = a.split('=')
try:
if k == 'device': # special DDP handling, i.e. device='0,1,2,3'
v = v.replace('[', '').replace(']', '') # handle device=[0,1,2,3]
v = v.replace(" ", "").replace('') # handle device=[0, 1, 2, 3]
v = v.replace('\\', '') # handle device=\'0,1,2,3\'
overrides[k] = v
else:
overrides[k] = eval(v) # convert strings to integers, floats, bools, etc.
except (NameError, SyntaxError):
overrides[k] = v
elif a in tasks:
overrides['task'] = a
elif a in modes:
overrides['mode'] = a
elif a in special_modes:
special_modes[a]()
return
elif a in defaults and defaults[a] is False:
overrides[a] = True # auto-True for default False args, i.e. 'yolo show' sets show=True
elif a in defaults:
raise SyntaxError(f"'{a}' is a valid YOLO argument but is missing an '=' sign to set its value, "
f"i.e. try '{a}={defaults[a]}'"
f"\n{CLI_HELP_MSG}")
else:
raise SyntaxError(
f"'{a}' is not a valid YOLO argument. For a full list of valid arguments see "
f"https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/configs/default.yaml"
f"\n{CLI_HELP_MSG}")
cfg = get_config(defaults, overrides) # create CFG instance
# Mapping from task to module
module = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify}.get(cfg.task)
if not module:
raise SyntaxError(f"yolo task={cfg.task} is invalid. Valid tasks are: {', '.join(tasks)}\n{CLI_HELP_MSG}")
# Mapping from mode to function
func = {
"train": module.train,
"val": module.val,
"predict": module.predict,
"export": yolo.engine.exporter.export}.get(cfg.mode)
if not func:
raise SyntaxError(f"yolo mode={cfg.mode} is invalid. Valid modes are: {', '.join(modes)}\n{CLI_HELP_MSG}")
func(cfg)
# Special modes --------------------------------------------------------------------------------------------------------
def copy_default_config():
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
shutil.copy2(DEFAULT_CFG_PATH, new_file)
LOGGER.info(f"{PREFIX}{DEFAULT_CFG_PATH} copied to {new_file}\n"
f"Usage for running YOLO with this new custom config:\nyolo cfg={new_file} args...")
if __name__ == '__main__':
entrypoint()

@ -1,68 +1,68 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
# Default training settings and hyperparameters for medium-augmentation COCO training
task: "detect" # choices=['detect', 'segment', 'classify', 'init'] # init is a special case. Specify task to run.
mode: "train" # choices=['train', 'val', 'predict'] # mode to run task in.
task: "detect" # choices=['detect', 'segment', 'classify', 'init'] # init is a special case. Specify task to run.
mode: "train" # choices=['train', 'val', 'predict'] # mode to run task in.
# Train settings -------------------------------------------------------------------------------------------------------
model: null # i.e. yolov8n.pt, yolov8n.yaml. Path to model file
data: null # i.e. coco128.yaml. Path to data file
epochs: 100 # number of epochs to train for
model: null # i.e. yolov8n.pt, yolov8n.yaml. Path to model file
data: null # i.e. coco128.yaml. Path to data file
epochs: 100 # number of epochs to train for
patience: 50 # epochs to wait for no observable improvement for early stopping of training
batch: 16 # number of images per batch
imgsz: 640 # size of input images
save: True # save checkpoints
cache: False # True/ram, disk or False. Use cache for data loading
device: null # cuda device, i.e. 0 or 0,1,2,3 or cpu. Device to run on
workers: 8 # number of worker threads for data loading
project: null # project name
name: null # experiment name
exist_ok: False # whether to overwrite existing experiment
pretrained: False # whether to use a pretrained model
optimizer: 'SGD' # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
verbose: False # whether to print verbose output
seed: 0 # random seed for reproducibility
deterministic: True # whether to enable deterministic mode
single_cls: False # train multi-class data as single-class
image_weights: False # use weighted image selection for training
rect: False # support rectangular training
cos_lr: False # use cosine learning rate scheduler
close_mosaic: 10 # disable mosaic augmentation for final 10 epochs
resume: False # resume training from last checkpoint
batch: 16 # number of images per batch
imgsz: 640 # size of input images
save: True # save checkpoints
cache: False # True/ram, disk or False. Use cache for data loading
device: null # cuda device, i.e. 0 or 0,1,2,3 or cpu. Device to run on
workers: 8 # number of worker threads for data loading
project: null # project name
name: null # experiment name
exist_ok: False # whether to overwrite existing experiment
pretrained: False # whether to use a pretrained model
optimizer: 'SGD' # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
verbose: False # whether to print verbose output
seed: 0 # random seed for reproducibility
deterministic: True # whether to enable deterministic mode
single_cls: False # train multi-class data as single-class
image_weights: False # use weighted image selection for training
rect: False # support rectangular training
cos_lr: False # use cosine learning rate scheduler
close_mosaic: 10 # disable mosaic augmentation for final 10 epochs
resume: False # resume training from last checkpoint
# Segmentation
overlap_mask: True # masks should overlap during training
mask_ratio: 4 # mask downsample ratio
overlap_mask: True # masks should overlap during training
mask_ratio: 4 # mask downsample ratio
# Classification
dropout: 0.0 # use dropout regularization
# Val/Test settings ----------------------------------------------------------------------------------------------------
val: True # validate/test during training
save_json: False # save results to JSON file
save_hybrid: False # save hybrid version of labels (labels + additional predictions)
conf: null # object confidence threshold for detection (default 0.25 predict, 0.001 val)
iou: 0.7 # intersection over union (IoU) threshold for NMS
max_det: 300 # maximum number of detections per image
half: False # use half precision (FP16)
dnn: False # use OpenCV DNN for ONNX inference
plots: True # show plots during training
val: True # validate/test during training
save_json: False # save results to JSON file
save_hybrid: False # save hybrid version of labels (labels + additional predictions)
conf: null # object confidence threshold for detection (default 0.25 predict, 0.001 val)
iou: 0.7 # intersection over union (IoU) threshold for NMS
max_det: 300 # maximum number of detections per image
half: False # use half precision (FP16)
dnn: False # use OpenCV DNN for ONNX inference
plots: True # show plots during training
# Prediction settings --------------------------------------------------------------------------------------------------
source: null # source directory for images or videos
show: False # show results if possible
save_txt: False # save results as .txt file
save_conf: False # save results with confidence scores
save_crop: False # save cropped images with results
hide_labels: False # hide labels
hide_conf: False # hide confidence scores
vid_stride: 1 # video frame-rate stride
line_thickness: 3 # bounding box thickness (pixels)
visualize: False # visualize results
augment: False # apply data augmentation to images
agnostic_nms: False # class-agnostic NMS
retina_masks: False # use retina masks for object detection
source: null # source directory for images or videos
show: False # show results if possible
save_txt: False # save results as .txt file
save_conf: False # save results with confidence scores
save_crop: False # save cropped images with results
hide_labels: False # hide labels
hide_conf: False # hide confidence scores
vid_stride: 1 # video frame-rate stride
line_thickness: 3 # bounding box thickness (pixels)
visualize: False # visualize results
augment: False # apply data augmentation to images
agnostic_nms: False # class-agnostic NMS
retina_masks: False # use retina masks for object detection
# Export settings ------------------------------------------------------------------------------------------------------
format: torchscript # format to export to
format: torchscript # format to export to
keras: False # use Keras
optimize: False # TorchScript: optimize for mobile
int8: False # CoreML/TF INT8 quantization
@ -100,12 +100,8 @@ mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability)
# Hydra configs --------------------------------------------------------------------------------------------------------
cfg: null # for overriding defaults.yaml
hydra:
output_subdir: null # disable hydra directory creation
run:
dir: .
# Custom config.yaml ---------------------------------------------------------------------------------------------------
cfg: null # for overriding defaults.yaml
# Debug, do not modify -------------------------------------------------------------------------------------------------
v5loader: False # use legacy YOLOv5 dataloader

@ -1,77 +0,0 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import sys
from difflib import get_close_matches
from textwrap import dedent
import hydra
from hydra.errors import ConfigCompositionException
from omegaconf import OmegaConf, open_dict # noqa
from omegaconf.errors import ConfigAttributeError, ConfigKeyError, OmegaConfBaseException # noqa
from ultralytics.yolo.utils import LOGGER, colorstr
def override_config(overrides, cfg):
override_keys = [override.key_or_group for override in overrides]
check_config_mismatch(override_keys, cfg.keys())
for override in overrides:
if override.package is not None:
raise ConfigCompositionException(f"Override {override.input_line} looks like a config group"
f" override, but config group '{override.key_or_group}' does not exist.")
key = override.key_or_group
value = override.value()
try:
if override.is_delete():
config_val = OmegaConf.select(cfg, key, throw_on_missing=False)
if config_val is None:
raise ConfigCompositionException(f"Could not delete from config. '{override.key_or_group}'"
" does not exist.")
elif value is not None and value != config_val:
raise ConfigCompositionException("Could not delete from config. The value of"
f" '{override.key_or_group}' is {config_val} and not"
f" {value}.")
last_dot = key.rfind(".")
with open_dict(cfg):
if last_dot == -1:
del cfg[key]
else:
node = OmegaConf.select(cfg, key[:last_dot])
del node[key[last_dot + 1:]]
elif override.is_add():
if OmegaConf.select(cfg, key, throw_on_missing=False) is None or isinstance(value, (dict, list)):
OmegaConf.update(cfg, key, value, merge=True, force_add=True)
else:
assert override.input_line is not None
raise ConfigCompositionException(
dedent(f"""\
Could not append to config. An item is already at '{override.key_or_group}'.
Either remove + prefix: '{override.input_line[1:]}'
Or add a second + to add or override '{override.key_or_group}': '+{override.input_line}'
"""))
elif override.is_force_add():
OmegaConf.update(cfg, key, value, merge=True, force_add=True)
else:
try:
OmegaConf.update(cfg, key, value, merge=True)
except (ConfigAttributeError, ConfigKeyError) as ex:
raise ConfigCompositionException(f"Could not override '{override.key_or_group}'."
f"\nTo append to your config use +{override.input_line}") from ex
except OmegaConfBaseException as ex:
raise ConfigCompositionException(f"Error merging override {override.input_line}").with_traceback(
sys.exc_info()[2]) from ex
def check_config_mismatch(overrides, cfg):
mismatched = [option for option in overrides if option not in cfg and 'hydra.' not in option]
for option in mismatched:
LOGGER.info(f"{colorstr(option)} is not a valid key. Similar keys: {get_close_matches(option, cfg, 3, 0.6)}")
if mismatched:
sys.exit()
hydra._internal.config_loader_impl.ConfigLoaderImpl._apply_overrides_to_config = override_config

@ -69,8 +69,8 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank
augment=mode == "train", # augmentation
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
rect=cfg.rect if mode == "train" else True, # rectangular batches
cache=cfg.get("cache", None),
single_cls=cfg.get("single_cls", False),
cache=cfg.cache or None,
single_cls=cfg.single_cls or False,
stride=int(stride),
pad=0.0 if mode == "train" else 0.5,
prefix=colorstr(f"{mode}: "),

@ -29,7 +29,8 @@ from torch.utils.data import DataLoader, Dataset, dataloader, distributed
from tqdm import tqdm
from ultralytics.yolo.data.utils import check_dataset, unzip_file
from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_kaggle
from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_dir_writeable,
is_kaggle)
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
from ultralytics.yolo.utils.ops import clean_str, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn
from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
@ -493,7 +494,7 @@ class LoadImagesAndLabels(Dataset):
cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
assert cache['version'] == self.cache_version # matches current version
assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash
except (FileNotFoundError, AssertionError):
except (FileNotFoundError, AssertionError, AttributeError):
cache, exists = self.cache_labels(cache_path, prefix), False # run cache ops
# Display cache
@ -579,16 +580,17 @@ class LoadImagesAndLabels(Dataset):
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
self.im_hw0, self.im_hw = [None] * n, [None] * n
fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
for i, x in pbar:
if cache_images == 'disk':
b += self.npy_files[i].stat().st_size
else: # 'ram'
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
b += self.ims[i].nbytes
pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
pbar.close()
with (Pool if n > 10000 else ThreadPool)(NUM_THREADS) as pool:
results = pool.imap(fcn, range(n))
pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
for i, x in pbar:
if cache_images == 'disk':
b += self.npy_files[i].stat().st_size
else: # 'ram'
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
b += self.ims[i].nbytes
pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
pbar.close()
def check_cache_ram(self, safety_margin=0.1, prefix=''):
# Check image caching requirements vs available memory
@ -612,11 +614,10 @@ class LoadImagesAndLabels(Dataset):
x = {} # dict
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f"{prefix}Scanning {path.parent / path.stem}..."
with Pool(NUM_THREADS) as pool:
pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
desc=desc,
total=len(self.im_files),
bar_format=TQDM_BAR_FORMAT)
total = len(self.im_files)
with (Pool if total > 10000 else ThreadPool)(NUM_THREADS) as pool:
results = pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix)))
pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
nm += nm_f
nf += nf_f
@ -627,8 +628,8 @@ class LoadImagesAndLabels(Dataset):
if msg:
msgs.append(msg)
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
pbar.close()
pbar.close()
if msgs:
LOGGER.info('\n'.join(msgs))
if nf == 0:
@ -637,12 +638,12 @@ class LoadImagesAndLabels(Dataset):
x['results'] = nf, nm, ne, nc, len(self.im_files)
x['msgs'] = msgs # warnings
x['version'] = self.cache_version # cache version
try:
np.save(path, x) # save cache for next time
if is_dir_writeable(path.parent):
np.save(str(path), x) # save cache for next time
path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
LOGGER.info(f'{prefix}New cache created: {path}')
except Exception as e:
LOGGER.warning(f'{prefix}WARNING ⚠ Cache directory {path.parent} is not writeable: {e}') # not writeable
else:
LOGGER.warning(f'{prefix}WARNING ⚠ Cache directory {path.parent} is not writeable') # not writeable
return x
def __len__(self):
@ -1148,8 +1149,10 @@ class HUBDatasetStats():
continue
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
desc = f'{split} images'
for _ in tqdm(ThreadPool(NUM_THREADS).imap(self._hub_ops, dataset.im_files), total=dataset.n, desc=desc):
pass
total = dataset.n
with (Pool if total > 10000 else ThreadPool)(NUM_THREADS) as pool:
for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=total, desc=desc):
pass
print(f'Done. All images saved to {self.im_dir}')
return self.im_dir

@ -1,13 +1,13 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
from itertools import repeat
from multiprocessing.pool import Pool
from multiprocessing.pool import Pool, ThreadPool
from pathlib import Path
import torchvision
from tqdm import tqdm
from ..utils import NUM_THREADS, TQDM_BAR_FORMAT
from ..utils import NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
from .augment import *
from .base import BaseDataset
from .utils import HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
@ -50,14 +50,12 @@ class YOLODataset(BaseDataset):
x = {"labels": []}
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
with Pool(NUM_THREADS) as pool:
pbar = tqdm(
pool.imap(verify_image_label,
zip(self.im_files, self.label_files, repeat(self.prefix), repeat(self.use_keypoints))),
desc=desc,
total=len(self.im_files),
bar_format=TQDM_BAR_FORMAT,
)
total = len(self.im_files)
with (Pool if total > 10000 else ThreadPool)(NUM_THREADS) as pool:
results = pool.imap(func=verify_image_label,
iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
repeat(self.use_keypoints)))
pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
nm += nm_f
nf += nf_f
@ -73,13 +71,12 @@ class YOLODataset(BaseDataset):
segments=segments,
keypoints=keypoint,
normalized=True,
bbox_format="xywh",
))
bbox_format="xywh"))
if msg:
msgs.append(msg)
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
pbar.close()
pbar.close()
if msgs:
LOGGER.info("\n".join(msgs))
if nf == 0:
@ -89,13 +86,12 @@ class YOLODataset(BaseDataset):
x["msgs"] = msgs # warnings
x["version"] = self.cache_version # cache version
self.im_files = [lb["im_file"] for lb in x["labels"]]
try:
np.save(path, x) # save cache for next time
if is_dir_writeable(path.parent):
np.save(str(path), x) # save cache for next time
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
LOGGER.info(f"{self.prefix}New cache created: {path}")
except Exception as e:
LOGGER.warning(
f"{self.prefix}WARNING ⚠ Cache directory {path.parent} is not writeable: {e}") # not writeable
else:
LOGGER.warning(f"{self.prefix}WARNING ⚠ Cache directory {path.parent} is not writeable") # not writeable
return x
def get_labels(self):
@ -105,7 +101,7 @@ class YOLODataset(BaseDataset):
cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True # load dict
assert cache["version"] == self.cache_version # matches current version
assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
except (FileNotFoundError, AssertionError):
except (FileNotFoundError, AssertionError, AttributeError):
cache, exists = self.cache_labels(cache_path), False # run cache ops
# Display cache

@ -99,7 +99,7 @@ names:
# Download script/URL (optional)
download: |
from ultralytics.yoloutils.downloads import download
from ultralytics.yolo.utils.downloads import download
from pathlib import Path
# Download labels

@ -60,7 +60,6 @@ from collections import defaultdict
from copy import deepcopy
from pathlib import Path
import hydra
import numpy as np
import pandas as pd
import torch
@ -71,7 +70,7 @@ from ultralytics.nn.tasks import ClassificationModel, DetectionModel, Segmentati
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
from ultralytics.yolo.data.utils import check_dataset
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, callbacks, colorstr, get_default_args, yaml_save
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, get_default_args, yaml_save
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
from ultralytics.yolo.utils.files import file_size
from ultralytics.yolo.utils.ops import Profile
@ -123,11 +122,11 @@ class Exporter:
A class for exporting a model.
Attributes:
args (OmegaConf): Configuration for the exporter.
args (SimpleNamespace): Configuration for the exporter.
save_dir (Path): Directory to save results.
"""
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
def __init__(self, config=DEFAULT_CFG, overrides=None):
"""
Initializes the Exporter class.
@ -135,8 +134,6 @@ class Exporter:
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
if overrides is None:
overrides = {}
self.args = get_config(config, overrides)
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
callbacks.add_integration_callbacks(self)
@ -799,8 +796,7 @@ class Exporter:
callback(self)
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def export(cfg):
def export(cfg=DEFAULT_CFG):
cfg.model = cfg.model or "yolov8n.yaml"
cfg.format = cfg.format or "torchscript"
@ -818,7 +814,7 @@ def export(cfg):
from ultralytics import YOLO
model = YOLO(cfg.model)
model.export(**cfg)
model.export(**vars(cfg))
if __name__ == "__main__":

@ -6,7 +6,7 @@ from ultralytics import yolo # noqa
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, yaml_load
from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
@ -151,7 +151,7 @@ class YOLO:
overrides = self.overrides.copy()
overrides.update(kwargs)
overrides["mode"] = "val"
args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
args = get_config(config=DEFAULT_CFG_PATH, overrides=overrides)
args.data = data or args.data
args.task = self.task
@ -169,7 +169,7 @@ class YOLO:
overrides = self.overrides.copy()
overrides.update(kwargs)
args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
args = get_config(config=DEFAULT_CFG_PATH, overrides=overrides)
args.task = self.task
print(args)

@ -36,7 +36,7 @@ from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, SETTINGS, callbacks, colorstr, ops
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, SETTINGS, callbacks, colorstr, ops
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
@ -49,7 +49,7 @@ class BasePredictor:
A base class for creating predictors.
Attributes:
args (OmegaConf): Configuration for the predictor.
args (SimpleNamespace): Configuration for the predictor.
save_dir (Path): Directory to save results.
done_setup (bool): Whether the predictor has finished setup.
model (nn.Module): Model used for prediction.
@ -62,7 +62,7 @@ class BasePredictor:
data_path (str): Path to data.
"""
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
def __init__(self, config=DEFAULT_CFG_PATH, overrides=None):
"""
Initializes the BasePredictor class.
@ -70,8 +70,6 @@ class BasePredictor:
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
if overrides is None:
overrides = {}
self.args = get_config(config, overrides)
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f"{self.args.mode}"
@ -157,7 +155,7 @@ class BasePredictor:
if stream:
return self.stream_inference(source, model, verbose)
else:
return list(chain(*list(self.stream_inference(source, model, verbose)))) # merge list of Result into one
return list(self.stream_inference(source, model, verbose)) # merge list of Result into one
def predict_cli(self):
# Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode
@ -211,7 +209,7 @@ class BasePredictor:
if self.args.save:
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
yield results
yield from results
# Print time (inference-only)
if verbose:

@ -15,8 +15,6 @@ import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from omegaconf import OmegaConf # noqa
from omegaconf import open_dict
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import lr_scheduler
@ -27,7 +25,7 @@ from ultralytics import __version__
from ultralytics.nn.tasks import attempt_load_one_weight
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
from ultralytics.yolo.utils import (DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
yaml_save)
from ultralytics.yolo.utils.autobatch import check_train_batch_size
from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
@ -43,7 +41,7 @@ class BaseTrainer:
A base class for creating trainers.
Attributes:
args (OmegaConf): Configuration for the trainer.
args (SimpleNamespace): Configuration for the trainer.
check_resume (method): Method to check if training should be resumed from a saved checkpoint.
console (logging.Logger): Logger instance.
validator (BaseValidator): Validator instance.
@ -73,7 +71,7 @@ class BaseTrainer:
csv (Path): Path to results CSV file.
"""
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
def __init__(self, config=DEFAULT_CFG_PATH, overrides=None):
"""
Initializes the BaseTrainer class.
@ -81,8 +79,6 @@ class BaseTrainer:
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
if overrides is None:
overrides = {}
self.args = get_config(config, overrides)
self.device = utils.torch_utils.select_device(self.args.device, self.args.batch)
self.check_resume()
@ -95,23 +91,23 @@ class BaseTrainer:
# Dirs
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f"{self.args.mode}"
self.save_dir = Path(
self.args.get(
"save_dir",
increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)))
if hasattr(self.args, 'save_dir'):
self.save_dir = Path(self.args.save_dir)
else:
self.save_dir = Path(
increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True))
self.wdir = self.save_dir / 'weights' # weights dir
if RANK in {-1, 0}:
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
with open_dict(self.args):
self.args.save_dir = str(self.save_dir)
yaml_save(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) # save run args
self.args.save_dir = str(self.save_dir)
yaml_save(self.save_dir / 'args.yaml', vars(self.args)) # save run args
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
self.batch_size = self.args.batch
self.epochs = self.args.epochs
self.start_epoch = 0
if RANK == -1:
print_args(dict(self.args))
print_args(vars(self.args))
# Device
self.amp = self.device.type != 'cpu'
@ -373,7 +369,7 @@ class BaseTrainer:
'ema': deepcopy(self.ema.ema).half(),
'updates': self.ema.updates,
'optimizer': self.optimizer.state_dict(),
'train_args': self.args,
'train_args': vars(self.args), # save as dict
'date': datetime.now().isoformat(),
'version': __version__}

@ -5,12 +5,12 @@ from collections import defaultdict
from pathlib import Path
import torch
from omegaconf import OmegaConf # noqa
from tqdm import tqdm
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks
from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.ops import Profile
@ -27,7 +27,7 @@ class BaseValidator:
dataloader (DataLoader): Dataloader to use for validation.
pbar (tqdm): Progress bar to update during validation.
logger (logging.Logger): Logger to use for validation.
args (OmegaConf): Configuration for the validator.
args (SimpleNamespace): Configuration for the validator.
model (nn.Module): Model to validate.
data (dict): Data dictionary.
device (torch.device): Device to use for validation.
@ -47,12 +47,12 @@ class BaseValidator:
save_dir (Path): Directory to save results.
pbar (tqdm.tqdm): Progress bar for displaying progress.
logger (logging.Logger): Logger to log messages.
args (OmegaConf): Configuration for the validator.
args (SimpleNamespace): Configuration for the validator.
"""
self.dataloader = dataloader
self.pbar = pbar
self.logger = logger or LOGGER
self.args = args or OmegaConf.load(DEFAULT_CONFIG)
self.args = args or get_config(DEFAULT_CFG_PATH)
self.model = None
self.data = None
self.device = None

@ -8,6 +8,7 @@ import platform
import sys
import tempfile
import threading
import types
import uuid
from pathlib import Path
from typing import Union
@ -22,7 +23,7 @@ import yaml
# Constants
FILE = Path(__file__).resolve()
ROOT = FILE.parents[2] # YOLO
DEFAULT_CONFIG = ROOT / "yolo/configs/default.yaml"
DEFAULT_CFG_PATH = ROOT / "yolo/configs/default.yaml"
RANK = int(os.getenv('RANK', -1))
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
@ -73,9 +74,10 @@ os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training
# Default config dictionary
with open(DEFAULT_CONFIG, errors='ignore') as f:
DEFAULT_CONFIG_DICT = yaml.safe_load(f)
DEFAULT_CONFIG_KEYS = DEFAULT_CONFIG_DICT.keys()
with open(DEFAULT_CFG_PATH, errors='ignore') as f:
DEFAULT_CFG_DICT = yaml.safe_load(f)
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
DEFAULT_CFG = types.SimpleNamespace(**DEFAULT_CFG_DICT)
def is_colab():

@ -28,7 +28,7 @@ def generate_ddp_file(trainer):
if not trainer.resume:
shutil.rmtree(trainer.save_dir) # remove the save_dir
content = f'''config = {dict(trainer.args)} \nif __name__ == "__main__":
content = f'''config = {vars(trainer.args)} \nif __name__ == "__main__":
from ultralytics.{import_path} import {trainer.__class__.__name__}
trainer = {trainer.__class__.__name__}(config=config)

@ -18,7 +18,7 @@ import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
import ultralytics
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER
from ultralytics.yolo.utils.checks import git_describe
from .checks import check_version
@ -288,7 +288,7 @@ def strip_optimizer(f='best.pt', s=''):
None
"""
x = torch.load(f, map_location=torch.device('cpu'))
args = {**DEFAULT_CONFIG_DICT, **x['train_args']} # combine model args with default args, preferring model args
args = {**DEFAULT_CFG_DICT, **x['train_args']} # combine model args with default args, preferring model args
if x.get('ema'):
x['model'] = x['ema'] # replace model with ema
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
@ -297,7 +297,8 @@ def strip_optimizer(f='best.pt', s=''):
x['model'].half() # to FP16
for p in x['model'].parameters():
p.requires_grad = False
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # strip non-default keys
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
# x['model'].args = x['train_args']
torch.save(x, s or f)
mb = os.path.getsize(s or f) / 1E6 # filesize
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")

@ -1,6 +1,5 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
from ultralytics.yolo.configs import hydra_patch # noqa (patch hydra CLI)
from ultralytics.yolo.v8 import classify, detect, segment
__all__ = ["classify", "segment", "detect"]

@ -1,11 +1,10 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import hydra
import torch
from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, is_git_directory
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, is_git_directory
from ultralytics.yolo.utils.plotting import Annotator
@ -64,8 +63,7 @@ class ClassificationPredictor(BasePredictor):
return log_string
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def predict(cfg):
def predict(cfg=DEFAULT_CFG):
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
else "https://ultralytics.com/images/bus.jpg"

@ -1,6 +1,5 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import hydra
import torch
import torchvision
@ -8,13 +7,13 @@ from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
from ultralytics.yolo import v8
from ultralytics.yolo.data import build_classification_dataloader
from ultralytics.yolo.engine.trainer import BaseTrainer
from ultralytics.yolo.utils import DEFAULT_CONFIG
from ultralytics.yolo.utils import DEFAULT_CFG
from ultralytics.yolo.utils.torch_utils import strip_optimizer
class ClassificationTrainer(BaseTrainer):
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
def __init__(self, config=DEFAULT_CFG, overrides=None):
if overrides is None:
overrides = {}
overrides["task"] = "classify"
@ -136,8 +135,7 @@ class ClassificationTrainer(BaseTrainer):
# self.run_callbacks('on_fit_epoch_end')
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def train(cfg):
def train(cfg=DEFAULT_CFG):
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
cfg.data = cfg.data or "mnist160" # or yolo.ClassificationDataset("mnist")
@ -152,7 +150,7 @@ def train(cfg):
# trainer.train()
from ultralytics import YOLO
model = YOLO(cfg.model)
model.train(**cfg)
model.train(**vars(cfg))
if __name__ == "__main__":

@ -1,10 +1,8 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import hydra
from ultralytics.yolo.data import build_classification_dataloader
from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import DEFAULT_CONFIG
from ultralytics.yolo.utils import DEFAULT_CFG
from ultralytics.yolo.utils.metrics import ClassifyMetrics
@ -46,8 +44,7 @@ class ClassificationValidator(BaseValidator):
self.logger.info(pf % ("all", self.metrics.top1, self.metrics.top5))
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def val(cfg):
def val(cfg=DEFAULT_CFG):
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
cfg.data = cfg.data or "imagenette160"
validator = ClassificationValidator(args=cfg)

@ -1,11 +1,10 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import hydra
import torch
from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, is_git_directory, ops
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, is_git_directory, ops
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
@ -81,8 +80,7 @@ class DetectionPredictor(BasePredictor):
return log_string
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def predict(cfg):
def predict(cfg=DEFAULT_CFG):
cfg.model = cfg.model or "yolov8n.pt"
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
else "https://ultralytics.com/images/bus.jpg"

@ -2,7 +2,6 @@
from copy import copy
import hydra
import torch
import torch.nn as nn
@ -11,7 +10,7 @@ from ultralytics.yolo import v8
from ultralytics.yolo.data import build_dataloader
from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
from ultralytics.yolo.engine.trainer import BaseTrainer
from ultralytics.yolo.utils import DEFAULT_CONFIG, colorstr
from ultralytics.yolo.utils import DEFAULT_CFG, colorstr
from ultralytics.yolo.utils.loss import BboxLoss
from ultralytics.yolo.utils.ops import xywh2xyxy
from ultralytics.yolo.utils.plotting import plot_images, plot_results
@ -30,7 +29,7 @@ class DetectionTrainer(BaseTrainer):
imgsz=self.args.imgsz,
batch_size=batch_size,
stride=gs,
hyp=dict(self.args),
hyp=vars(self.args),
augment=mode == "train",
cache=self.args.cache,
pad=0 if mode == "train" else 0.5,
@ -195,8 +194,7 @@ class Loss:
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def train(cfg):
def train(cfg=DEFAULT_CFG):
cfg.model = cfg.model or "yolov8n.pt"
cfg.data = cfg.data or "coco128.yaml" # or yolo.ClassificationDataset("mnist")
cfg.device = cfg.device if cfg.device is not None else ''
@ -204,7 +202,7 @@ def train(cfg):
# trainer.train()
from ultralytics import YOLO
model = YOLO(cfg.model)
model.train(**cfg)
model.train(**vars(cfg))
if __name__ == "__main__":

@ -3,14 +3,13 @@
import os
from pathlib import Path
import hydra
import numpy as np
import torch
from ultralytics.yolo.data import build_dataloader
from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import DEFAULT_CONFIG, colorstr, ops, yaml_load
from ultralytics.yolo.utils import DEFAULT_CFG, colorstr, ops, yaml_load
from ultralytics.yolo.utils.checks import check_file, check_requirements
from ultralytics.yolo.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
@ -168,7 +167,7 @@ class DetectionValidator(BaseValidator):
imgsz=self.args.imgsz,
batch_size=batch_size,
stride=gs,
hyp=dict(self.args),
hyp=vars(self.args),
cache=False,
pad=0.5,
rect=True,
@ -232,8 +231,7 @@ class DetectionValidator(BaseValidator):
return stats
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def val(cfg):
def val(cfg=DEFAULT_CFG):
cfg.model = cfg.model or "yolov8n.pt"
cfg.data = cfg.data or "coco128.yaml"
validator = DetectionValidator(args=cfg)

@ -1,10 +1,9 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import hydra
import torch
from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, is_git_directory, ops
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, is_git_directory, ops
from ultralytics.yolo.utils.plotting import colors, save_one_box
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
@ -98,8 +97,7 @@ class SegmentationPredictor(DetectionPredictor):
return log_string
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def predict(cfg):
def predict(cfg=DEFAULT_CFG):
cfg.model = cfg.model or "yolov8n-seg.pt"
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
else "https://ultralytics.com/images/bus.jpg"

@ -2,13 +2,12 @@
from copy import copy
import hydra
import torch
import torch.nn.functional as F
from ultralytics.nn.tasks import SegmentationModel
from ultralytics.yolo import v8
from ultralytics.yolo.utils import DEFAULT_CONFIG
from ultralytics.yolo.utils import DEFAULT_CFG
from ultralytics.yolo.utils.ops import crop_mask, xyxy2xywh
from ultralytics.yolo.utils.plotting import plot_images, plot_results
from ultralytics.yolo.utils.tal import make_anchors
@ -19,7 +18,7 @@ from ultralytics.yolo.v8.detect.train import Loss
# BaseTrainer python usage
class SegmentationTrainer(v8.detect.DetectionTrainer):
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
def __init__(self, config=DEFAULT_CFG, overrides=None):
if overrides is None:
overrides = {}
overrides["task"] = "segment"
@ -141,8 +140,7 @@ class SegLoss(Loss):
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def train(cfg):
def train(cfg=DEFAULT_CFG):
cfg.model = cfg.model or "yolov8n-seg.pt"
cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist")
cfg.device = cfg.device if cfg.device is not None else ''
@ -150,7 +148,7 @@ def train(cfg):
# trainer.train()
from ultralytics import YOLO
model = YOLO(cfg.model)
model.train(**cfg)
model.train(**vars(cfg))
if __name__ == "__main__":

@ -4,12 +4,11 @@ import os
from multiprocessing.pool import ThreadPool
from pathlib import Path
import hydra
import numpy as np
import torch
import torch.nn.functional as F
from ultralytics.yolo.utils import DEFAULT_CONFIG, NUM_THREADS, ops
from ultralytics.yolo.utils import DEFAULT_CFG, NUM_THREADS, ops
from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.yolo.utils.metrics import ConfusionMatrix, SegmentMetrics, box_iou, mask_iou
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
@ -243,8 +242,7 @@ class SegmentationValidator(DetectionValidator):
return stats
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def val(cfg):
def val(cfg=DEFAULT_CFG):
cfg.data = cfg.data or "coco128-seg.yaml"
validator = SegmentationValidator(args=cfg)
validator(model=cfg.model)

Loading…
Cancel
Save