Add Hyperparameter evolution `Tuner()` class (#4599)

Glenn Jocher committed by GitHub
parent 7e99804263
commit 4bd62a299c
Changed files, with the number of lines changed in each:

1. docs/guides/hyperparameter-tuning.md (96)
2. docs/guides/index.md (1)
3. docs/integrations/ray-tune.md (35)
4. docs/reference/engine/tuner.md (14)
5. docs/usage/cfg.md (82)
6. mkdocs.yml (2)
7. setup.py (2)
8. tests/test_cuda.py (14)
9. ultralytics/__init__.py (2)
10. ultralytics/cfg/__init__.py (4)
11. ultralytics/engine/model.py (33)
12. ultralytics/engine/tuner.py (183)
13. ultralytics/utils/metrics.py (1)
14. ultralytics/utils/ops.py (13)
15. ultralytics/utils/tuner.py (16)

@ -0,0 +1,96 @@
---
comments: true
description: Dive into hyperparameter tuning in Ultralytics YOLO models. Learn how to optimize performance using the Tuner class and genetic evolution.
keywords: Ultralytics, YOLO, Hyperparameter Tuning, Tuner Class, Genetic Evolution, Optimization
---
# Ultralytics YOLO Hyperparameter Tuning Guide
## Introduction
Hyperparameter tuning is not just a one-time setup but an iterative process aimed at optimizing the machine learning model's performance metrics, such as accuracy, precision, and recall. In the context of Ultralytics YOLO, these hyperparameters could range from the learning rate to architectural details, such as the number of layers or the types of activation functions used.
### What are Hyperparameters?
Hyperparameters are high-level, structural settings for the algorithm. They are set prior to the training phase and remain constant during it. Here are some commonly tuned hyperparameters in Ultralytics YOLO:
- **Learning Rate**: Determines the step size at each iteration while moving towards a minimum in the loss function.
- **Batch Size**: Number of training samples utilized in one iteration.
- **Number of Epochs**: An epoch is one complete forward and backward pass of all the training examples.
- **Architecture Specifics**: Such as anchor box sizes, number of layers, types of activation functions, etc.
<p align="center">
<img width="1000" src="https://user-images.githubusercontent.com/26833433/263858934-4f109a2f-82d9-4d08-8bd6-6fd1ff520bcd.png" alt="Hyperparameter Tuning Visual">
</p>
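The hyperparameters listed above are ordinary training arguments in Ultralytics YOLO, so a hand-picked configuration can be tried directly before starting any automated search. A minimal sketch (the dataset and values below are illustrative only):

```python
from ultralytics import YOLO

# Train once with a manually chosen set of hyperparameters
model = YOLO('yolov8n.pt')
model.train(
    data='coco8.yaml',  # example dataset
    epochs=30,          # number of epochs
    batch=16,           # batch size
    lr0=0.01,           # initial learning rate
)
```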
For a full list of augmentation hyperparameters used in YOLOv8, please refer to https://docs.ultralytics.com/usage/cfg/#augmentation.
### Genetic Evolution and Mutation
Ultralytics YOLO uses genetic algorithms to optimize hyperparameters. Genetic algorithms are inspired by the mechanism of natural selection and genetics.
- **Mutation**: In the context of Ultralytics YOLO, mutation helps in locally searching the hyperparameter space by applying small, random changes to existing hyperparameters, producing new candidates for evaluation.
- **Crossover**: Although crossover is a popular genetic algorithm technique, it is not currently used in Ultralytics YOLO for hyperparameter tuning. The focus is mainly on mutation for generating new hyperparameter sets.
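To make the mutation step concrete, here is a simplified, self-contained sketch of the idea: each hyperparameter is occasionally scaled by a small random factor and then clipped back into its allowed range. The bounds, probabilities, and scaling below are illustrative; the actual implementation lives in the `Tuner` class in `ultralytics/engine/tuner.py`.

```python
import numpy as np

# Illustrative (lower, upper) bounds for a few hyperparameters
space = {'lr0': (1e-5, 1e-1), 'momentum': (0.6, 0.98), 'weight_decay': (0.0, 0.001)}

def mutate(parent, mutation_p=0.8, sigma=0.2, rng=np.random.default_rng(0)):
    """Apply small multiplicative Gaussian perturbations, then clip to the bounds in `space`."""
    child = {}
    for key, (lo, hi) in space.items():
        factor = 1.0
        if rng.random() < mutation_p:  # mutate this gene with probability mutation_p
            factor = float(np.clip(1.0 + sigma * rng.standard_normal(), 0.3, 3.0))
        child[key] = float(np.clip(parent[key] * factor, lo, hi))
    return child

parent = {'lr0': 0.01, 'momentum': 0.937, 'weight_decay': 0.0005}
print(mutate(parent))  # a mutated copy of the parent hyperparameters
```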
## Preparing for Hyperparameter Tuning
Before you begin the tuning process, it's important to:
1. **Identify the Metrics**: Determine the metrics you will use to evaluate the model's performance. This could be AP50, F1-score, or others.
2. **Set the Tuning Budget**: Define how much computational resources you're willing to allocate. Hyperparameter tuning can be computationally intensive.
## Steps Involved
### Initialize Hyperparameters
Start with a reasonable set of initial hyperparameters. These could be the default hyperparameters provided by Ultralytics YOLO or values based on your domain knowledge or previous experiments.
### Mutate Hyperparameters
Use the `_mutate` method to produce a new set of hyperparameters based on the existing set.
### Train Model
Training is performed using the mutated set of hyperparameters. The training performance is then assessed.
### Evaluate Model
Use metrics like AP50, F1-score, or custom metrics to evaluate the model's performance.
### Log Results
It's crucial to log both the performance metrics and the corresponding hyperparameters for future reference.
### Repeat
The process is repeated until either the set number of iterations is reached or the performance metric is satisfactory.
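Taken together, the steps above form a simple evolve loop. The built-in `Tuner` automates all of this; the sketch below is only a conceptual outline, with `mutate` and `train_and_evaluate` as hypothetical placeholders for the mutation and training/evaluation steps and a CSV log format chosen purely for illustration.

```python
import csv

def evolve(n_iterations, initial_hyp, mutate, train_and_evaluate, log_path='evolve.csv'):
    """Conceptual tuning loop: mutate -> train -> evaluate -> log, repeated n_iterations times."""
    parent, best_hyp, best_fitness = dict(initial_hyp), dict(initial_hyp), float('-inf')
    with open(log_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['fitness'] + list(initial_hyp))               # header row
        for _ in range(n_iterations):
            candidate = mutate(parent)                                 # 1) mutate hyperparameters
            fitness = train_and_evaluate(candidate)                    # 2-3) train and evaluate (e.g. weighted mAP)
            writer.writerow([fitness] + list(candidate.values()))      # 4) log fitness and hyperparameters
            if fitness > best_fitness:                                 # keep the best candidate as the next parent
                parent, best_hyp, best_fitness = candidate, candidate, fitness
    return best_hyp, best_fitness
```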
## Usage Example
Here's how to use the `model.tune()` method to utilize the `Tuner` class for hyperparameter tuning:
!!! example ""

    === "Python"

        ```python
        from ultralytics import YOLO

        # Initialize the YOLO model
        model = YOLO('yolov8n.pt')

        # Perform hyperparameter tuning
        model.tune(data='coco8.yaml', imgsz=640, epochs=30, iterations=300)
        ```
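When tuning finishes, the `Tuner` writes its results to a tuning run directory: an `evolve.csv` log with one row of fitness and hyperparameter values per iteration, and a `best.yaml` file with the best hyperparameters found. The sketch below assumes a default detection run location of `runs/detect/tune`; check the console output for the actual path on your system.

```python
from pathlib import Path

import numpy as np
import yaml

tune_dir = Path('runs/detect/tune')  # assumed location; the real path is printed by the Tuner

# Each evolve.csv row is the fitness score followed by the mutated hyperparameter values
history = np.loadtxt(tune_dir / 'evolve.csv', delimiter=',', skiprows=1, ndmin=2)
print(f'Best fitness: {history[:, 0].max():.4f} over {len(history)} iterations')

# best.yaml holds the hyperparameters of the best iteration
best_hyp = yaml.safe_load((tune_dir / 'best.yaml').read_text())
print(best_hyp)
```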
## Conclusion
The hyperparameter tuning process in Ultralytics YOLO is simplified yet powerful, thanks to its genetic algorithm-based approach focused on mutation. Following the steps outlined in this guide will assist you in systematically tuning your model to achieve better performance.
### Further Reading
1. [Hyperparameter Optimization in Wikipedia](https://en.wikipedia.org/wiki/Hyperparameter_optimization)
2. [YOLOv5 Hyperparameter Evolution Guide](https://docs.ultralytics.com/yolov5/tutorials/hyperparameter_evolution/)
3. [Efficient Hyperparameter Tuning with Ray Tune and YOLOv8](https://docs.ultralytics.com/integrations/ray-tune/)
For deeper insights, you can explore the `Tuner` class source code and accompanying documentation. Should you have any questions, feature requests, or need further assistance, feel free to reach out to our support team.

@ -15,5 +15,6 @@ Whether you're a beginner or an expert in deep learning, our tutorials offer val
Here's a compilation of in-depth guides to help you master different aspects of Ultralytics YOLO.
* [K-Fold Cross Validation](kfold-cross-validation.md) 🚀 NEW: Learn how to improve model generalization using K-Fold cross-validation technique.
* [Hyperparameter Tuning](hyperparameter-tuning.md) 🚀 NEW: Discover how to optimize your YOLO models by fine-tuning hyperparameters using the Tuner class and genetic evolution algorithms.
Note: More guides about training, exporting, predicting, and deploying with Ultralytics YOLO are coming soon. Stay tuned!

@ -30,27 +30,31 @@ To install the required packages, run:

!!! tip "Installation"

    ```bash
    # Install and update Ultralytics and Ray Tune packages
    pip install -U ultralytics "ray[tune]"

    # Optionally install W&B for logging
    pip install wandb
    ```

    === "CLI"

        ```bash
        # Install and update Ultralytics and Ray Tune packages
        pip install -U ultralytics "ray[tune]"

        # Optionally install W&B for logging
        pip install wandb
        ```

## Usage

!!! example "Usage"

    ```python
    from ultralytics import YOLO

    # Load a YOLOv8n model
    model = YOLO("yolov8n.pt")

    # Start tuning hyperparameters for YOLOv8n training on the COCO128 dataset
    result_grid = model.tune(data="coco128.yaml")
    ```

    === "Python"

        ```python
        from ultralytics import YOLO

        # Load a YOLOv8n model
        model = YOLO('yolov8n.pt')

        # Start tuning hyperparameters for YOLOv8n training on the COCO8 dataset
        result_grid = model.tune(data='coco8.yaml', use_ray=True)
        ```
## `tune()` Method Parameters
@ -62,7 +66,7 @@ The `tune()` method in YOLOv8 provides an easy-to-use interface for hyperparamet
| `space` | `dict, optional` | A dictionary defining the hyperparameter search space for Ray Tune. Each key corresponds to a hyperparameter name, and the value specifies the range of values to explore during tuning. If not provided, YOLOv8 uses a default search space with various hyperparameters. | |
| `grace_period` | `int, optional` | The grace period in epochs for the [ASHA scheduler](https://docs.ray.io/en/latest/tune/api/schedulers.html) in Ray Tune. The scheduler will not terminate any trial before this number of epochs, allowing the model to have some minimum training before making a decision on early stopping. | 10 |
| `gpu_per_trial` | `int, optional` | The number of GPUs to allocate per trial during tuning. This helps manage GPU usage, particularly in multi-GPU environments. If not provided, the tuner will use all available GPUs. | None |
| `max_samples` | `int, optional` | The maximum number of trials to run during tuning. This parameter helps control the total number of hyperparameter combinations tested, ensuring the tuning process does not run indefinitely. | 10 |
| `iterations` | `int, optional` | The maximum number of trials to run during tuning. This parameter helps control the total number of hyperparameter combinations tested, ensuring the tuning process does not run indefinitely. | 10 |
| `**train_args` | `dict, optional` | Additional arguments to pass to the `train()` method during tuning. These arguments can include settings like the number of training epochs, batch size, and other training-specific configurations. | {} |
By customizing these parameters, you can fine-tune the hyperparameter optimization process to suit your specific needs and available computational resources.
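For instance, here is a sketch that combines several of the parameters above in one call; the values are illustrative only:

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')

# Illustrative settings: 20 trials, 1 GPU per trial, ASHA grace period of 5 epochs,
# with `epochs` passed through to train() as part of **train_args
result_grid = model.tune(
    data='coco8.yaml',
    use_ray=True,
    iterations=20,
    gpu_per_trial=1,
    grace_period=5,
    epochs=30,
)
```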
@ -110,7 +114,8 @@ In this example, we demonstrate how to use a custom search space for hyperparame
# Run Ray Tune on the model
result_grid = model.tune(data="coco128.yaml",
space={"lr0": tune.uniform(1e-5, 1e-1)},
epochs=50)
epochs=50,
use_ray=True)
```
In the code snippet above, we create a YOLO model with the "yolov8n.pt" pretrained weights. Then, we call the `tune()` method, specifying the dataset configuration with "coco128.yaml". We provide a custom search space for the initial learning rate `lr0` using a dictionary with the key "lr0" and the value `tune.uniform(1e-5, 1e-1)`. Finally, we pass additional training arguments, such as the number of epochs, directly to the `tune()` method as `epochs=50`, along with `use_ray=True` to run the search with Ray Tune.

@ -0,0 +1,14 @@
---
description: Explore the Ultralytics Tuner, a powerful tool designed for hyperparameter tuning of YOLO models to optimize performance across various tasks like object detection, image classification, and more.
keywords: Ultralytics, Tuner, YOLO, hyperparameter tuning, optimization, object detection, image classification, instance segmentation, pose estimation, multi-object tracking
---
# Reference for `ultralytics/engine/tuner.py`
!!! note

    Full source code for this file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/engine/tuner.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/engine/tuner.py). Help us fix any issues you see by submitting a [Pull Request](https://docs.ultralytics.com/help/contributing/) 🛠. Thank you 🙏!
---
## ::: ultralytics.engine.tuner.Tuner
<br><br>

@ -4,11 +4,9 @@ description: Master YOLOv8 settings and hyperparameters for improved model perfo
keywords: YOLOv8, settings, hyperparameters, YOLO CLI commands, YOLO tasks, YOLO modes, Ultralytics documentation, model optimization, YOLOv8 training
---
YOLO settings and hyperparameters play a critical role in the model's performance, speed, and accuracy. These settings
and hyperparameters can affect the model's behavior at various stages of the model development process, including
training, validation, and prediction.
YOLO settings and hyperparameters play a critical role in the model's performance, speed, and accuracy. These settings and hyperparameters can affect the model's behavior at various stages of the model development process, including training, validation, and prediction.
YOLOv8 'yolo' CLI commands use the following syntax:
YOLOv8 `yolo` CLI commands use the following syntax:
!!! example ""
@ -32,18 +30,15 @@ YOLOv8 'yolo' CLI commands use the following syntax:
Where:
- `TASK` (optional) is one of `[detect, segment, classify, pose]`. If it is not passed explicitly YOLOv8 will try to
guess
the `TASK` from the model type.
- `TASK` (optional) is one of `[detect, segment, classify, pose]`. If it is not passed explicitly YOLOv8 will try to guess the `TASK` from the model type.
- `MODE` (required) is one of `[train, val, predict, export, track, benchmark]`
- `ARGS` (optional) are any number of custom `arg=value` pairs like `imgsz=320` that override defaults.
For a full list of available `ARGS` see the [Configuration](cfg.md) page and `defaults.yaml`
GitHub [source](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml).
For a full list of available `ARGS` see the [Configuration](cfg.md) page and `defaults.yaml` GitHub [source](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml).
#### Tasks
YOLO models can be used for a variety of tasks, including detection, segmentation, classification and pose. These tasks
differ in the type of output they produce and the specific problem they are designed to solve.
YOLO models can be used for a variety of tasks, including detection, segmentation, classification and pose. These tasks differ in the type of output they produce and the specific problem they are designed to solve.
**Detect**: For identifying and localizing objects or regions of interest in an image or video.
**Segment**: For dividing an image or video into regions or pixels that correspond to different objects or classes.
@ -58,8 +53,7 @@ differ in the type of output they produce and the specific problem they are desi
#### Modes
YOLO models can be used in different modes depending on the specific problem you are trying to solve. These modes
include:
YOLO models can be used in different modes depending on the specific problem you are trying to solve. These modes include:
**Train**: For training a YOLOv8 model on a custom dataset.
**Val**: For validating a YOLOv8 model after it has been trained.
@ -202,50 +196,34 @@ Export settings for YOLO models encompass configurations and options related to
## Augmentation
Augmentation settings for YOLO models refer to the various transformations and modifications
applied to the training data to increase the diversity and size of the dataset. These settings can affect the model's
performance, speed, and accuracy. Some common YOLO augmentation settings include the type and intensity of the
transformations applied (e.g. random flips, rotations, cropping, color changes), the probability with which each
transformation is applied, and the presence of additional features such as masks or multiple labels per box. Other
factors that may affect the augmentation process include the size and composition of the original dataset and the
specific task the model is being used for. It is important to carefully tune and experiment with these settings to
ensure that the augmented dataset is diverse and representative enough to train a high-performing model.
| Key | Value | Description |
|---------------|-------|-------------------------------------------------|
| `hsv_h` | 0.015 | image HSV-Hue augmentation (fraction) |
| `hsv_s` | 0.7 | image HSV-Saturation augmentation (fraction) |
| `hsv_v` | 0.4 | image HSV-Value augmentation (fraction) |
| `degrees` | 0.0 | image rotation (+/- deg) |
| `translate` | 0.1 | image translation (+/- fraction) |
| `scale` | 0.5 | image scale (+/- gain) |
| `shear` | 0.0 | image shear (+/- deg) |
| `perspective` | 0.0 | image perspective (+/- fraction), range 0-0.001 |
| `flipud` | 0.0 | image flip up-down (probability) |
| `fliplr` | 0.5 | image flip left-right (probability) |
| `mosaic` | 1.0 | image mosaic (probability) |
| `mixup` | 0.0 | image mixup (probability) |
| `copy_paste` | 0.0 | segment copy-paste (probability) |
Augmentation settings for YOLO models refer to the various transformations and modifications applied to the training data to increase the diversity and size of the dataset. These settings can affect the model's performance, speed, and accuracy. Some common YOLO augmentation settings include the type and intensity of the transformations applied (e.g. random flips, rotations, cropping, color changes), the probability with which each transformation is applied, and the presence of additional features such as masks or multiple labels per box. Other factors that may affect the augmentation process include the size and composition of the original dataset and the specific task the model is being used for. It is important to carefully tune and experiment with these settings to ensure that the augmented dataset is diverse and representative enough to train a high-performing model.
| Key | Value | Description |
|---------------|---------|-------------------------------------------------|
| `hsv_h` | `0.015` | image HSV-Hue augmentation (fraction) |
| `hsv_s` | `0.7` | image HSV-Saturation augmentation (fraction) |
| `hsv_v` | `0.4` | image HSV-Value augmentation (fraction) |
| `degrees` | `0.0` | image rotation (+/- deg) |
| `translate` | `0.1` | image translation (+/- fraction) |
| `scale` | `0.5` | image scale (+/- gain) |
| `shear` | `0.0` | image shear (+/- deg) |
| `perspective` | `0.0` | image perspective (+/- fraction), range 0-0.001 |
| `flipud` | `0.0` | image flip up-down (probability) |
| `fliplr` | `0.5` | image flip left-right (probability) |
| `mosaic` | `1.0` | image mosaic (probability) |
| `mixup` | `0.0` | image mixup (probability) |
| `copy_paste` | `0.0` | segment copy-paste (probability) |
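All of these augmentation keys are regular training arguments, so individual values can be overridden directly in a `train()` call. A minimal sketch with illustrative values:

```python
from ultralytics import YOLO

# Override selected augmentation hyperparameters for a single training run
model = YOLO('yolov8n.pt')
model.train(
    data='coco8.yaml',  # example dataset
    epochs=10,
    hsv_h=0.015,   # HSV-Hue augmentation (fraction)
    degrees=10.0,  # rotation (+/- deg)
    fliplr=0.5,    # flip left-right (probability)
    mosaic=1.0,    # mosaic (probability)
)
```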
## Logging, checkpoints, plotting and file management
Logging, checkpoints, plotting, and file management are important considerations when training a YOLO model.
- Logging: It is often helpful to log various metrics and statistics during training to track the model's progress and
diagnose any issues that may arise. This can be done using a logging library such as TensorBoard or by writing log
messages to a file.
- Checkpoints: It is a good practice to save checkpoints of the model at regular intervals during training. This allows
you to resume training from a previous point if the training process is interrupted or if you want to experiment with
different training configurations.
- Plotting: Visualizing the model's performance and training progress can be helpful for understanding how the model is
behaving and identifying potential issues. This can be done using a plotting library such as matplotlib or by
generating plots using a logging library such as TensorBoard.
- File management: Managing the various files generated during the training process, such as model checkpoints, log
files, and plots, can be challenging. It is important to have a clear and organized file structure to keep track of
these files and make it easy to access and analyze them as needed.
Effective logging, checkpointing, plotting, and file management can help you keep track of the model's progress and make
it easier to debug and optimize the training process.
- Logging: It is often helpful to log various metrics and statistics during training to track the model's progress and diagnose any issues that may arise. This can be done using a logging library such as TensorBoard or by writing log messages to a file.
- Checkpoints: It is a good practice to save checkpoints of the model at regular intervals during training. This allows you to resume training from a previous point if the training process is interrupted or if you want to experiment with different training configurations.
- Plotting: Visualizing the model's performance and training progress can be helpful for understanding how the model is behaving and identifying potential issues. This can be done using a plotting library such as matplotlib or by generating plots using a logging library such as TensorBoard.
- File management: Managing the various files generated during the training process, such as model checkpoints, log files, and plots, can be challenging. It is important to have a clear and organized file structure to keep track of these files and make it easy to access and analyze them as needed.
Effective logging, checkpointing, plotting, and file management can help you keep track of the model's progress and make it easier to debug and optimize the training process.
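As a concrete illustration of the logging point above, the sketch below writes a couple of per-epoch metrics to TensorBoard; the log directory, metric names, and values are placeholders, not output produced by YOLO itself.

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/experiment1')  # illustrative log directory
for epoch in range(3):
    # In a real run these would be the metrics computed for the epoch
    writer.add_scalar('metrics/mAP50', 0.40 + 0.01 * epoch, epoch)
    writer.add_scalar('train/box_loss', 1.0 - 0.05 * epoch, epoch)
writer.close()
```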
| Key | Value | Description |
|------------|----------|------------------------------------------------------------------------------------------------|

@ -215,6 +215,7 @@ nav:
- Guides:
- guides/index.md
- K-Fold Cross Validation: guides/kfold-cross-validation.md
- Hyperparameter Tuning: guides/hyperparameter-tuning.md
- Integrations:
- integrations/index.md
- OpenVINO: integrations/openvino.md
@ -279,6 +280,7 @@ nav:
- predictor: reference/engine/predictor.md
- results: reference/engine/results.md
- trainer: reference/engine/trainer.md
- tuner: reference/engine/tuner.md
- validator: reference/engine/validator.md
- hub:
- __init__: reference/hub/__init__.md

@ -47,7 +47,7 @@ setup(
'mkdocs-material',
'mkdocstrings[python]',
'mkdocs-redirects', # for 301 redirects
'mkdocs-ultralytics-plugin>=0.0.26', # for meta descriptions and images, dates and authors
'mkdocs-ultralytics-plugin>=0.0.27', # for meta descriptions and images, dates and authors
],
'export': [
'coremltools>=7.0.b1',

@ -1,6 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import subprocess
from pathlib import Path
import pytest
@ -81,18 +80,23 @@ def test_predict_sam():
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
def test_model_tune():
subprocess.run('pip install ray[tune]'.split(), check=True)
def test_model_ray_tune():
with contextlib.suppress(RuntimeError): # RuntimeError may be caused by out-of-memory
YOLO('yolov8n-cls.yaml').tune(data='imagenet10',
YOLO('yolov8n-cls.yaml').tune(use_ray=True,
data='imagenet10',
grace_period=1,
max_samples=1,
iterations=1,
imgsz=32,
epochs=1,
plots=False,
device='cpu')
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
def test_model_tune():
YOLO('yolov8n.pt').tune(data='coco8.yaml', imgsz=32, epochs=1, iterations=1, device='cpu')
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
def test_pycocotools():
from ultralytics.models.yolo.detect import DetectionValidator

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.164'
__version__ = '8.0.165'
from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM

@ -146,7 +146,7 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
return IterableSimpleNamespace(**cfg)
def get_save_dir(args):
def get_save_dir(args, name=None):
"""Return save_dir as created from train/val/predict arguments."""
if getattr(args, 'save_dir', None):
@ -155,7 +155,7 @@ def get_save_dir(args):
from ultralytics.utils.files import increment_path
project = args.project or Path(SETTINGS['runs_dir']) / args.task
name = args.name or f'{args.mode}'
name = name or args.name or f'{args.mode}'
save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True)
return Path(save_dir)

@ -344,6 +344,25 @@ class Model:
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
self.overrides = self.model.args
self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
return self.metrics
def tune(self, use_ray=False, iterations=10, *args, **kwargs):
"""
Runs hyperparameter tuning, optionally using Ray Tune. See ultralytics.utils.tuner.run_ray_tune for Args.
Returns:
(dict): A dictionary containing the results of the hyperparameter search.
"""
self._check_is_pytorch_model()
if use_ray:
from ultralytics.utils.tuner import run_ray_tune
return run_ray_tune(self, max_samples=iterations, *args, **kwargs)
else:
from .tuner import Tuner
custom = {} # method defaults
args = {**self.overrides, **custom, **kwargs, 'mode': 'export'} # highest priority args on the right
return Tuner(args=args, _callbacks=self.callbacks)(model=self.model, iterations=iterations)
def to(self, device):
"""
@ -356,20 +375,6 @@ class Model:
self.model.to(device)
return self
def tune(self, *args, **kwargs):
"""
Runs hyperparameter tuning using Ray Tune. See ultralytics.utils.tuner.run_ray_tune for Args.
Returns:
(dict): A dictionary containing the results of the hyperparameter search.
Raises:
ModuleNotFoundError: If Ray Tune is not installed.
"""
self._check_is_pytorch_model()
from ultralytics.utils.tuner import run_ray_tune
return run_ray_tune(self, *args, **kwargs)
@property
def names(self):
"""Returns class names of the loaded model."""

@ -0,0 +1,183 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
This module provides functionalities for hyperparameter tuning of the Ultralytics YOLO models for object detection,
instance segmentation, image classification, pose estimation, and multi-object tracking.
Hyperparameter tuning is the process of systematically searching for the optimal set of hyperparameters
that yield the best model performance. This is particularly crucial in deep learning models like YOLO,
where small changes in hyperparameters can lead to significant differences in model accuracy and efficiency.
Example:
Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
```python
from ultralytics import YOLO
model = YOLO('yolov8n.pt')
model.tune(data='coco8.yaml', imgsz=640, epochs=100, iterations=10)
```
"""
import random
import time
import numpy as np
from ultralytics import YOLO
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, yaml_print, yaml_save
class Tuner:
    """
    Class responsible for hyperparameter tuning of YOLO models.

    The class evolves YOLO model hyperparameters over a given number of iterations
    by mutating them according to the search space and retraining the model to evaluate their performance.

    Attributes:
        space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
        tune_dir (Path): Directory where evolution logs and results will be saved.
        evolve_csv (Path): Path to the CSV file where evolution logs are saved.

    Methods:
        _mutate(hyp: dict) -> dict:
            Mutates the given hyperparameters within the bounds specified in `self.space`.

        __call__():
            Executes the hyperparameter evolution across multiple iterations.

    Example:
        Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
        ```python
        from ultralytics import YOLO

        model = YOLO('yolov8n.pt')
        model.tune(data='coco8.yaml', imgsz=640, epochs=100, iterations=10)
        ```
    """
    def __init__(self, args=DEFAULT_CFG, _callbacks=None):
        """
        Initialize the Tuner with configurations.

        Args:
            args (dict, optional): Configuration for hyperparameter evolution.
        """
        self.args = get_cfg(overrides=args)
        self.space = {
            # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
            'lr0': (1e-5, 1e-1),
            'lrf': (0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
            'momentum': (0.6, 0.98),  # SGD momentum/Adam beta1
            'weight_decay': (0.0, 0.001),  # optimizer weight decay 5e-4
            'warmup_epochs': (0.0, 5.0),  # warmup epochs (fractions ok)
            'warmup_momentum': (0.0, 0.95),  # warmup initial momentum
            'box': (0.02, 0.2),  # box loss gain
            'cls': (0.2, 4.0),  # cls loss gain (scale with pixels)
            'hsv_h': (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
            'hsv_s': (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
            'hsv_v': (0.0, 0.9),  # image HSV-Value augmentation (fraction)
            'degrees': (0.0, 45.0),  # image rotation (+/- deg)
            'translate': (0.0, 0.9),  # image translation (+/- fraction)
            'scale': (0.0, 0.9),  # image scale (+/- gain)
            'shear': (0.0, 10.0),  # image shear (+/- deg)
            'perspective': (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
            'flipud': (0.0, 1.0),  # image flip up-down (probability)
            'fliplr': (0.0, 1.0),  # image flip left-right (probability)
            'mosaic': (0.0, 1.0),  # image mosaic (probability)
            'mixup': (0.0, 1.0),  # image mixup (probability)
            'copy_paste': (0.0, 1.0)}  # segment copy-paste (probability)
        self.tune_dir = get_save_dir(self.args, name='tune')
        self.evolve_csv = self.tune_dir / 'evolve.csv'
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        callbacks.add_integration_callbacks(self)
        LOGGER.info(f"Initialized Tuner instance with 'tune_dir={self.tune_dir}'.")
    def _mutate(self, parent='single', n=5, mutation=0.8, sigma=0.6, return_best=False):
        """
        Mutates the hyperparameters based on bounds and scaling factors specified in `self.space`.

        Args:
            parent (str): Parent selection method: 'single' or 'weighted'.
            n (int): Number of parents to consider.
            mutation (float): Probability of a parameter mutation in any given iteration.
            sigma (float): Standard deviation for Gaussian random number generator.

        Returns:
            (dict): A dictionary containing mutated hyperparameters.
        """
        if self.evolve_csv.exists():  # if evolve.csv exists: select best hyps and mutate
            # Select parent(s)
            x = np.loadtxt(self.evolve_csv, ndmin=2, delimiter=',', skiprows=1)
            fitness = x[:, 0]  # first column
            n = min(n, len(x))  # number of previous results to consider
            x = x[np.argsort(-fitness)][:n]  # top n mutations
            if return_best:
                return {k: float(x[0, i + 1]) for i, k in enumerate(self.space.keys())}
            fitness = x[:, 0]  # first column
            w = fitness - fitness.min() + 1E-6  # weights (sum > 0)
            if parent == 'single' or len(x) == 1:
                # x = x[random.randint(0, n - 1)]  # random selection
                x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
            elif parent == 'weighted':
                x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

            # Mutate
            r = np.random  # method
            r.seed(int(time.time()))
            g = np.array([self.space[k][0] for k in self.space.keys()])  # gains 0-1
            ng = len(self.space)
            v = np.ones(ng)
            while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                v = (g * (r.random(ng) < mutation) * r.randn(ng) * r.random() * sigma + 1).clip(0.3, 3.0)
            hyp = {k: float(x[i + 1] * v[i]) for i, k in enumerate(self.space.keys())}
        else:
            hyp = {k: getattr(self.args, k) for k in self.space.keys()}

        # Constrain to limits
        for k, v in self.space.items():
            hyp[k] = max(hyp[k], v[0])  # lower limit
            hyp[k] = min(hyp[k], v[1])  # upper limit
            hyp[k] = round(hyp[k], 5)  # significant digits

        return hyp
    def __call__(self, model=None, iterations=10, prefix=colorstr('Tuner:')):
        """
        Executes the hyperparameter evolution process when the Tuner instance is called.

        This method iterates through the number of iterations, performing the following steps in each iteration:
        1. Load the existing hyperparameters or initialize new ones.
        2. Mutate the hyperparameters using the `_mutate` method.
        3. Train a YOLO model with the mutated hyperparameters.
        4. Log the fitness score and mutated hyperparameters to a CSV file.

        Args:
            model (YOLO): A pre-initialized YOLO model to be used for training.
            iterations (int): The number of generations to run the evolution for.

        Note:
            The method utilizes the `self.evolve_csv` Path object to read and log hyperparameters and fitness scores.
            Ensure this path is set correctly in the Tuner instance.
        """
        self.tune_dir.mkdir(parents=True, exist_ok=True)
        for i in range(iterations):
            # Mutate hyperparameters
            mutated_hyp = self._mutate()
            LOGGER.info(f'{prefix} Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}')

            # Initialize and train YOLOv8 model
            model = YOLO('yolov8n.pt')
            train_args = {**vars(self.args), **mutated_hyp}
            results = model.train(**train_args)

            # Save results and mutated_hyp to evolve_csv
            headers = '' if self.evolve_csv.exists() else (','.join(['fitness_score'] + list(self.space.keys())) + '\n')
            log_row = [results.fitness] + [mutated_hyp[k] for k in self.space.keys()]
            with open(self.evolve_csv, 'a') as f:
                f.write(headers + ','.join(map(str, log_row)) + '\n')

        LOGGER.info(f'{prefix} All iterations complete. Results saved to {colorstr("bold", self.tune_dir)}')
        best_hyp = self._mutate(return_best=True)  # best hyps
        yaml_save(self.tune_dir / 'best.yaml', best_hyp)
        yaml_print(self.tune_dir / 'best.yaml')

@ -520,7 +520,6 @@ class Metric(SimpleClass):
maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,).
fitness(): Model fitness as a weighted combination of metrics. Returns: Float.
update(results): Update metric attributes with new evaluation results.
"""
def __init__(self) -> None:

@ -17,6 +17,16 @@ from ultralytics.utils import LOGGER
class Profile(contextlib.ContextDecorator):
"""
YOLOv8 Profile class. Use as a decorator with @Profile() or as a context manager with 'with Profile():'.
Example:
```python
from ultralytics.utils.ops import Profile
with Profile() as dt:
pass # slow operation here
print(dt) # prints "Elapsed time is 9.5367431640625e-07 s"
```
"""
def __init__(self, t=0.0):
@ -39,6 +49,9 @@ class Profile(contextlib.ContextDecorator):
self.dt = self.time() - self.start # delta-time
self.t += self.dt # accumulate dt
def __str__(self):
return f'Elapsed time is {self.t} s'
def time(self):
"""Get current time."""
if self.cuda:

@ -1,5 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import subprocess
from ultralytics.cfg import TASK2DATA, TASK2METRIC
from ultralytics.utils import DEFAULT_CFG_DICT, LOGGER, NUM_THREADS
@ -24,13 +26,23 @@ def run_ray_tune(model,
Returns:
(dict): A dictionary containing the results of the hyperparameter search.
Raises:
ModuleNotFoundError: If Ray Tune is not installed.
Example:
```python
from ultralytics import YOLO
# Load a YOLOv8n model
model = YOLO('yolov8n.pt')
# Start tuning hyperparameters for YOLOv8n training on the COCO8 dataset
result_grid = model.tune(data='coco8.yaml', use_ray=True)
```
"""
if train_args is None:
train_args = {}
try:
subprocess.run('pip install ray[tune]'.split(), check=True)
from ray import tune
from ray.air import RunConfig
from ray.air.integrations.wandb import WandbLoggerCallback
