Add Hyperparameter evolution `Tuner()` class (#4599)
parent
7e99804263
commit
4bd62a299c
15 changed files with 405 additions and 93 deletions
@@ -0,0 +1,96 @@
---
comments: true
description: Dive into hyperparameter tuning in Ultralytics YOLO models. Learn how to optimize performance using the Tuner class and genetic evolution.
keywords: Ultralytics, YOLO, Hyperparameter Tuning, Tuner Class, Genetic Evolution, Optimization
---

# Ultralytics YOLO Hyperparameter Tuning Guide

## Introduction

Hyperparameter tuning is not just a one-time set-up but an iterative process aimed at optimizing the machine learning model's performance metrics, such as accuracy, precision, and recall. In the context of Ultralytics YOLO, these hyperparameters could range from learning rate to architectural details, such as the number of layers or types of activation functions used.

### What are Hyperparameters?

Hyperparameters are high-level, structural settings for the algorithm. They are set prior to the training phase and remain constant during it. Here are some commonly tuned hyperparameters in Ultralytics YOLO:

- **Learning Rate**: Determines the step size at each iteration while moving towards a minimum in the loss function.
- **Batch Size**: Number of training samples utilized in one iteration.
- **Number of Epochs**: An epoch is one complete pass through the entire training dataset; this hyperparameter sets how many such passes are made.
- **Architecture Specifics**: Such as anchor box sizes, number of layers, types of activation functions, etc.

<p align="center">
  <img width="1000" src="https://user-images.githubusercontent.com/26833433/263858934-4f109a2f-82d9-4d08-8bd6-6fd1ff520bcd.png" alt="Hyperparameter Tuning Visual">
</p>

For a full list of augmentation hyperparameters used in YOLOv8, please refer to https://docs.ultralytics.com/usage/cfg/#augmentation.
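
Many of these hyperparameters can also be fixed directly as training overrides. Below is a minimal sketch, assuming the standard Ultralytics training API; the values shown are illustrative, not recommendations:

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')

# Fix selected hyperparameters as training overrides (illustrative values)
model.train(data='coco8.yaml', epochs=30, lr0=0.01, momentum=0.937, hsv_h=0.015)
```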

### Genetic Evolution and Mutation

Ultralytics YOLO uses genetic algorithms to optimize hyperparameters. Genetic algorithms are inspired by the mechanism of natural selection and genetics.

- **Mutation**: In the context of Ultralytics YOLO, mutation helps in locally searching the hyperparameter space by applying small, random changes to existing hyperparameters, producing new candidates for evaluation (see the sketch after this list).
- **Crossover**: Although crossover is a popular genetic algorithm technique, it is not currently used in Ultralytics YOLO for hyperparameter tuning. The focus is mainly on mutation for generating new hyperparameter sets.
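
To make the mutation step concrete, here is a simplified, self-contained sketch of bounded multiplicative Gaussian mutation. It illustrates the idea only and is not the exact `Tuner` implementation:

```python
import random

def mutate(hyp, space, mutation=0.8, sigma=0.6):
    """Perturb each hyperparameter multiplicatively, then clip to its (min, max) bounds."""
    new_hyp = {}
    for k, (low, high) in space.items():
        v = hyp[k]
        if random.random() < mutation:  # mutate this parameter with probability `mutation`
            factor = max(0.3, min(3.0, 1 + random.gauss(0, sigma)))  # bounded Gaussian gain
            v *= factor
        new_hyp[k] = round(min(max(v, low), high), 5)  # constrain to the search space
    return new_hyp

space = {'lr0': (1e-5, 1e-1), 'momentum': (0.6, 0.98)}
print(mutate({'lr0': 0.01, 'momentum': 0.937}, space))
```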

## Preparing for Hyperparameter Tuning

Before you begin the tuning process, it's important to:

1. **Identify the Metrics**: Determine the metrics you will use to evaluate the model's performance. This could be AP50, F1-score, or others.
2. **Set the Tuning Budget**: Define the computational resources you're willing to allocate. Hyperparameter tuning can be computationally intensive.

## Steps Involved

### Initialize Hyperparameters

Start with a reasonable set of initial hyperparameters. This could either be the default hyperparameters set by Ultralytics YOLO or something based on your domain knowledge or previous experiments.
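
For instance, the defaults can be inspected programmatically. A minimal sketch, assuming the `get_cfg` helper from `ultralytics.cfg` (the same helper the `Tuner` class uses internally):

```python
from ultralytics.cfg import get_cfg

cfg = get_cfg()  # load the Ultralytics default configuration
print(cfg.lr0, cfg.momentum, cfg.weight_decay)  # a few of the starting hyperparameters
```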

### Mutate Hyperparameters

Use the `_mutate` method to produce a new set of hyperparameters based on the existing set, as sketched below.
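
In normal use `_mutate` is an internal method called for you by `model.tune()`, but the step can be sketched in isolation (assuming the class is importable from `ultralytics.engine.tuner`):

```python
from ultralytics.engine.tuner import Tuner

tuner = Tuner(args=dict(data='coco8.yaml', epochs=30))
hyp = tuner._mutate()  # on the first call (no evolve.csv yet) this returns the starting values;
print(hyp)             # later calls mutate the best previous results within the search-space bounds
```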

### Train Model

Training is performed using the mutated set of hyperparameters. The training performance is then assessed.

### Evaluate Model

Use metrics like AP50, F1-score, or custom metrics to evaluate the model's performance.
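
For example, a trained model can be validated and a metric read back as the tuning objective. A minimal sketch, assuming the standard Ultralytics validation API:

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
metrics = model.val(data='coco8.yaml')  # run validation
print(metrics.box.map50)  # mean AP at IoU 0.50, one possible tuning objective
```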

### Log Results

It's crucial to log both the performance metrics and the corresponding hyperparameters for future reference.
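
The `Tuner` class does this automatically, appending one row per iteration to an `evolve.csv` file with the fitness score in the first column followed by the hyperparameter values. A short sketch of inspecting it afterwards (the path shown is illustrative):

```python
import numpy as np

# Load the evolution log: a header row, then one row per tuning iteration
x = np.loadtxt('runs/detect/tune/evolve.csv', ndmin=2, delimiter=',', skiprows=1)
best = x[x[:, 0].argmax()]  # row with the highest fitness score
print(f'Best fitness: {best[0]:.5f}, hyperparameters: {best[1:]}')
```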

### Repeat

The process is repeated until either the set number of iterations is reached or the performance metric is satisfactory.

## Usage Example

Here's how to use the `model.tune()` method, which drives the `Tuner` class, for hyperparameter tuning:

!!! example ""

    === "Python"

        ```python
        from ultralytics import YOLO

        # Initialize the YOLO model
        model = YOLO('yolov8n.pt')

        # Perform hyperparameter tuning
        model.tune(data='coco8.yaml', imgsz=640, epochs=30, iterations=300)
        ```
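
Once tuning completes, the per-iteration fitness scores and hyperparameters are appended to `evolve.csv` in the tuning directory, and the best-performing hyperparameters are saved to `best.yaml` for reuse in later trainings.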

## Conclusion

The hyperparameter tuning process in Ultralytics YOLO is simplified yet powerful, thanks to its genetic algorithm-based approach focused on mutation. Following the steps outlined in this guide will assist you in systematically tuning your model to achieve better performance.

### Further Reading

1. [Hyperparameter Optimization on Wikipedia](https://en.wikipedia.org/wiki/Hyperparameter_optimization)
2. [YOLOv5 Hyperparameter Evolution Guide](https://docs.ultralytics.com/yolov5/tutorials/hyperparameter_evolution/)
3. [Efficient Hyperparameter Tuning with Ray Tune and YOLOv8](https://docs.ultralytics.com/integrations/ray-tune/)

For deeper insights, you can explore the `Tuner` class source code and accompanying documentation. Should you have any questions, feature requests, or need further assistance, feel free to reach out to our support team.
@@ -0,0 +1,183 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
This module provides functionalities for hyperparameter tuning of the Ultralytics YOLO models for object detection,
instance segmentation, image classification, pose estimation, and multi-object tracking.

Hyperparameter tuning is the process of systematically searching for the optimal set of hyperparameters
that yield the best model performance. This is particularly crucial in deep learning models like YOLO,
where small changes in hyperparameters can lead to significant differences in model accuracy and efficiency.

Example:
    Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
    ```python
    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')
    model.tune(data='coco8.yaml', imgsz=640, epochs=30, iterations=300)
    ```
"""
import random
import time
from copy import deepcopy

import numpy as np

from ultralytics import YOLO
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, yaml_print, yaml_save


class Tuner:
    """
    Class responsible for hyperparameter tuning of YOLO models.

    The class evolves YOLO model hyperparameters over a given number of iterations
    by mutating them according to the search space and retraining the model to evaluate their performance.

    Attributes:
        space (dict): Hyperparameter search space containing (min, max) bounds for mutation.
        tune_dir (Path): Directory where evolution logs and results will be saved.
        evolve_csv (Path): Path to the CSV file where evolution logs are saved.

    Methods:
        _mutate(parent='single', n=5, mutation=0.8, sigma=0.6, return_best=False) -> dict:
            Mutates the hyperparameters within the bounds specified in `self.space`.

        __call__(model=None, iterations=10):
            Executes the hyperparameter evolution across multiple iterations.

    Example:
        Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
        ```python
        from ultralytics import YOLO

        model = YOLO('yolov8n.pt')
        model.tune(data='coco8.yaml', imgsz=640, epochs=30, iterations=300)
        ```
    """

    def __init__(self, args=DEFAULT_CFG, _callbacks=None):
        """
        Initialize the Tuner with configurations.

        Args:
            args (dict, optional): Configuration for hyperparameter evolution.
            _callbacks (list, optional): Callback functions to be invoked during tuning.
        """
        self.args = get_cfg(overrides=args)
        self.space = {
            # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
            'lr0': (1e-5, 1e-1),
            'lrf': (0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
            'momentum': (0.6, 0.98),  # SGD momentum/Adam beta1
            'weight_decay': (0.0, 0.001),  # optimizer weight decay 5e-4
            'warmup_epochs': (0.0, 5.0),  # warmup epochs (fractions ok)
            'warmup_momentum': (0.0, 0.95),  # warmup initial momentum
            'box': (0.02, 0.2),  # box loss gain
            'cls': (0.2, 4.0),  # cls loss gain (scale with pixels)
            'hsv_h': (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
            'hsv_s': (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
            'hsv_v': (0.0, 0.9),  # image HSV-Value augmentation (fraction)
            'degrees': (0.0, 45.0),  # image rotation (+/- deg)
            'translate': (0.0, 0.9),  # image translation (+/- fraction)
            'scale': (0.0, 0.9),  # image scale (+/- gain)
            'shear': (0.0, 10.0),  # image shear (+/- deg)
            'perspective': (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
            'flipud': (0.0, 1.0),  # image flip up-down (probability)
            'fliplr': (0.0, 1.0),  # image flip left-right (probability)
            'mosaic': (0.0, 1.0),  # image mosaic (probability)
            'mixup': (0.0, 1.0),  # image mixup (probability)
            'copy_paste': (0.0, 1.0)}  # segment copy-paste (probability)
        self.tune_dir = get_save_dir(self.args, name='tune')
        self.evolve_csv = self.tune_dir / 'evolve.csv'
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        callbacks.add_integration_callbacks(self)
        LOGGER.info(f"Initialized Tuner instance with 'tune_dir={self.tune_dir}'.")

    def _mutate(self, parent='single', n=5, mutation=0.8, sigma=0.6, return_best=False):
        """
        Mutates the hyperparameters based on the bounds specified in `self.space`.

        Args:
            parent (str): Parent selection method: 'single' or 'weighted'.
            n (int): Number of parents to consider.
            mutation (float): Probability of a parameter mutation in any given iteration.
            sigma (float): Standard deviation for Gaussian random number generator.
            return_best (bool): If True, return the best hyperparameters found so far instead of mutating.

        Returns:
            (dict): A dictionary containing mutated hyperparameters.
        """
        if self.evolve_csv.exists():  # if evolve.csv exists: select best hyps and mutate
            # Select parent(s)
            x = np.loadtxt(self.evolve_csv, ndmin=2, delimiter=',', skiprows=1)
            fitness = x[:, 0]  # first column
            n = min(n, len(x))  # number of previous results to consider
            x = x[np.argsort(-fitness)][:n]  # top n mutations
            if return_best:
                return {k: float(x[0, i + 1]) for i, k in enumerate(self.space.keys())}
            fitness = x[:, 0]  # fitness of the top-n subset
            w = fitness - fitness.min() + 1E-6  # weights (sum > 0)
            if parent == 'single' or len(x) == 1:
                # x = x[random.randint(0, n - 1)]  # random selection
                x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
            elif parent == 'weighted':
                x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

            # Mutate
            r = np.random  # random number generator
            r.seed(int(time.time()))
            ng = len(self.space)
            g = np.ones(ng)  # uniform mutation gain per hyperparameter
            v = np.ones(ng)
            while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                v = (g * (r.random(ng) < mutation) * r.randn(ng) * r.random() * sigma + 1).clip(0.3, 3.0)
            hyp = {k: float(x[i + 1] * v[i]) for i, k in enumerate(self.space.keys())}
        else:
            hyp = {k: getattr(self.args, k) for k in self.space.keys()}

        # Constrain to limits
        for k, v in self.space.items():
            hyp[k] = max(hyp[k], v[0])  # lower limit
            hyp[k] = min(hyp[k], v[1])  # upper limit
            hyp[k] = round(hyp[k], 5)  # significant digits

        return hyp

    def __call__(self, model=None, iterations=10, prefix=colorstr('Tuner:')):
        """
        Executes the hyperparameter evolution process when the Tuner instance is called.

        This method iterates through the number of iterations, performing the following steps in each iteration:
        1. Load the existing hyperparameters or initialize new ones.
        2. Mutate the hyperparameters using the `_mutate` method.
        3. Train a YOLO model with the mutated hyperparameters.
        4. Log the fitness score and mutated hyperparameters to a CSV file.

        Args:
            model (YOLO): A pre-initialized YOLO model to be used for training.
            iterations (int): The number of generations to run the evolution for.
            prefix (str): Colored prefix for log messages.

        Note:
            The method utilizes the `self.evolve_csv` Path object to read and log hyperparameters and fitness scores.
            Ensure this path is set correctly in the Tuner instance.
        """
        self.tune_dir.mkdir(parents=True, exist_ok=True)
        for i in range(iterations):
            # Mutate hyperparameters
            mutated_hyp = self._mutate()
            LOGGER.info(f'{prefix} Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}')

            # Train a fresh copy of the model each iteration so every run starts from the provided weights
            train_model = deepcopy(model) if model is not None else YOLO('yolov8n.pt')
            train_args = {**vars(self.args), **mutated_hyp}
            results = train_model.train(**train_args)

            # Save results and mutated_hyp to evolve_csv
            headers = '' if self.evolve_csv.exists() else (','.join(['fitness_score'] + list(self.space.keys())) + '\n')
            log_row = [results.fitness] + [mutated_hyp[k] for k in self.space.keys()]
            with open(self.evolve_csv, 'a') as f:
                f.write(headers + ','.join(map(str, log_row)) + '\n')

        LOGGER.info(f'{prefix} All iterations complete. Results saved to {colorstr("bold", self.tune_dir)}')
        best_hyp = self._mutate(return_best=True)  # best hyperparameters found during evolution
        yaml_save(self.tune_dir / 'best.yaml', best_hyp)
        yaml_print(self.tune_dir / 'best.yaml')