@@ -13,18 +13,20 @@ Example:
     from ultralytics import YOLO
 
     model = YOLO('yolov8n.pt')
-    model.tune(data='coco8.yaml', imgsz=640, epochs=100, iterations=10)
+    model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
     ```
 """
 import random
+import shutil
+import subprocess
 import time
-from copy import deepcopy
 
 import numpy as np
+import torch
 
-from ultralytics import YOLO
 from ultralytics.cfg import get_cfg, get_save_dir
-from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, yaml_print, yaml_save
+from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, remove_colorstr, yaml_print, yaml_save
+from ultralytics.utils.plotting import plot_tune_results
 
 
 class Tuner:
@@ -37,7 +39,7 @@ class Tuner:
     Attributes:
         space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
         tune_dir (Path): Directory where evolution logs and results will be saved.
-        evolve_csv (Path): Path to the CSV file where evolution logs are saved.
+        tune_csv (Path): Path to the CSV file where evolution logs are saved.
 
     Methods:
         _mutate(hyp: dict) -> dict:
@@ -52,7 +54,7 @@ class Tuner:
         from ultralytics import YOLO
 
         model = YOLO('yolov8n.pt')
-        model.tune(data='coco8.yaml', imgsz=640, epochs=100, iterations=10, val=False, cache=True)
+        model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
         ```
     """
 
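The `model.tune()` call in the example above is the supported entry point; under the hood it builds a `Tuner` from the same train arguments and invokes it. A minimal sketch of driving the class directly, assuming a detection dataset config and the default callbacks (the argument values here are hypothetical, not part of this diff):

```python
from ultralytics.engine.tuner import Tuner

# Hypothetical settings; 'model' and 'data' are needed so each iteration can launch a training run.
args = dict(model='yolov8n.pt', data='coco8.yaml', epochs=10, plots=False, save=False, val=False)
Tuner(args=args)(iterations=5)  # runs 5 mutate -> train -> log cycles; results land under the Tuner's tune_dir
```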
@@ -64,22 +66,23 @@ class Tuner:
             args (dict, optional): Configuration for hyperparameter evolution.
         """
         self.args = get_cfg(overrides=args)
-        self.space = {  # key: (min, max, gain(optionaL))
+        self.space = {  # key: (min, max, gain(optional))
             # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
             'lr0': (1e-5, 1e-1),
-            'lrf': (0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
-            'momentum': (0.6, 0.98, 0.3),  # SGD momentum/Adam beta1
+            'lrf': (0.0001, 0.1),  # final OneCycleLR learning rate (lr0 * lrf)
+            'momentum': (0.7, 0.98, 0.3),  # SGD momentum/Adam beta1
             'weight_decay': (0.0, 0.001),  # optimizer weight decay 5e-4
             'warmup_epochs': (0.0, 5.0),  # warmup epochs (fractions ok)
             'warmup_momentum': (0.0, 0.95),  # warmup initial momentum
-            'box': (0.02, 0.2),  # box loss gain
+            'box': (1.0, 20.0),  # box loss gain
             'cls': (0.2, 4.0),  # cls loss gain (scale with pixels)
             'dfl': (0.4, 6.0),  # dfl loss gain
             'hsv_h': (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
             'hsv_s': (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
             'hsv_v': (0.0, 0.9),  # image HSV-Value augmentation (fraction)
             'degrees': (0.0, 45.0),  # image rotation (+/- deg)
             'translate': (0.0, 0.9),  # image translation (+/- fraction)
-            'scale': (0.0, 0.9),  # image scale (+/- gain)
+            'scale': (0.0, 0.95),  # image scale (+/- gain)
             'shear': (0.0, 10.0),  # image shear (+/- deg)
             'perspective': (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
             'flipud': (0.0, 1.0),  # image flip up-down (probability)
@@ -87,11 +90,13 @@ class Tuner:
             'mosaic': (0.0, 1.0),  # image mixup (probability)
             'mixup': (0.0, 1.0),  # image mixup (probability)
             'copy_paste': (0.0, 1.0)}  # segment copy-paste (probability)
-        self.tune_dir = get_save_dir(self.args, name='_tune')
-        self.evolve_csv = self.tune_dir / 'evolve.csv'
+        self.tune_dir = get_save_dir(self.args, name='tune')
+        self.tune_csv = self.tune_dir / 'tune_results.csv'
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
+        self.prefix = colorstr('Tuner: ')
         callbacks.add_integration_callbacks(self)
-        LOGGER.info(f"Initialized Tuner instance with 'tune_dir={self.tune_dir}'.")
+        LOGGER.info(f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
+                    f'{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning')
 
     def _mutate(self, parent='single', n=5, mutation=0.8, sigma=0.2):
         """
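Each entry in `self.space` is a `(min, max)` bound with an optional third `gain` term that scales how aggressively that key is mutated. The actual sampling lives in `_mutate()` below; the following is only an illustrative sketch of the convention, with made-up starting values, not the method's exact math:

```python
import numpy as np

space = {'lr0': (1e-5, 1e-1), 'momentum': (0.7, 0.98, 0.3)}  # (min, max[, gain]) per key
hyp = {'lr0': 0.01, 'momentum': 0.937}                       # hypothetical current values

rng = np.random.default_rng(0)
for k, bounds in space.items():
    gain = bounds[2] if len(bounds) == 3 else 1.0             # optional gain scales the mutation step
    hyp[k] *= 1.0 + rng.normal(0.0, 0.2) * gain               # multiplicative Gaussian nudge (sigma=0.2)
    hyp[k] = round(min(max(hyp[k], bounds[0]), bounds[1]), 5)  # clamp back into [min, max] and round
print(hyp)
```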
@@ -106,9 +111,9 @@ class Tuner:
         Returns:
             (dict): A dictionary containing mutated hyperparameters.
         """
-        if self.evolve_csv.exists():  # if evolve.csv exists: select best hyps and mutate
+        if self.tune_csv.exists():  # if CSV file exists: select best hyps and mutate
             # Select parent(s)
-            x = np.loadtxt(self.evolve_csv, ndmin=2, delimiter=',', skiprows=1)
+            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=',', skiprows=1)
             fitness = x[:, 0]  # first column
             n = min(n, len(x))  # number of previous results to consider
             x = x[np.argsort(-fitness)][:n]  # top n mutations
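The same `np.loadtxt` call also works for inspecting `tune_results.csv` offline once a few iterations have been logged. A small sketch, assuming the header row written later in this diff (`fitness` followed by the search-space keys) and a hypothetical results path:

```python
import numpy as np

csv = 'runs/detect/tune/tune_results.csv'  # hypothetical location under the Tuner's tune_dir
keys = open(csv).readline().strip().split(',')[1:]       # hyperparameter names from the header row
x = np.loadtxt(csv, ndmin=2, delimiter=',', skiprows=1)  # skip the header row
fitness = x[:, 0]                                        # first column is the fitness score
best_idx = int(fitness.argmax())
best = dict(zip(keys, x[best_idx, 1:]))
print(f'best fitness {fitness[best_idx]:.5f} at iteration {best_idx + 1}: {best}')
```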
@@ -139,7 +144,7 @@ class Tuner:
 
         return hyp
 
-    def __call__(self, model=None, iterations=10, prefix=colorstr('Tuner:')):
+    def __call__(self, model=None, iterations=10, cleanup=True):
         """
         Executes the hyperparameter evolution process when the Tuner instance is called.
 
@@ -152,54 +157,68 @@ class Tuner:
         Args:
             model (Model): A pre-initialized YOLO model to be used for training.
             iterations (int): The number of generations to run the evolution for.
+            cleanup (bool): Whether to delete iteration weights to reduce storage space used during tuning.
 
         Note:
-            The method utilizes the `self.evolve_csv` Path object to read and log hyperparameters and fitness scores.
+            The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
             Ensure this path is set correctly in the Tuner instance.
         """
 
         t0 = time.time()
         best_save_dir, best_metrics = None, None
-        self.tune_dir.mkdir(parents=True, exist_ok=True)
+        (self.tune_dir / 'weights').mkdir(parents=True, exist_ok=True)
         for i in range(iterations):
             # Mutate hyperparameters
             mutated_hyp = self._mutate()
-            LOGGER.info(f'{prefix} Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}')
+            LOGGER.info(f'{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}')
 
+            metrics = {}
+            train_args = {**vars(self.args), **mutated_hyp}
+            save_dir = get_save_dir(get_cfg(train_args))
             try:
-                # Train YOLO model with mutated hyperparameters
-                train_args = {**vars(self.args), **mutated_hyp}
-                results = (deepcopy(model) or YOLO(self.args.model)).train(**train_args)
-                fitness = results.fitness
+                # Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang)
+                weights_dir = save_dir / 'weights'
+                cmd = ['yolo', 'train', *(f'{k}={v}' for k, v in train_args.items())]
+                assert subprocess.run(cmd, check=True).returncode == 0, 'training failed'
+                ckpt_file = weights_dir / ('best.pt' if (weights_dir / 'best.pt').exists() else 'last.pt')
+                metrics = torch.load(ckpt_file)['train_metrics']
+
             except Exception as e:
-                LOGGER.warning(f'WARNING ❌️ training failure for hyperparameter tuning iteration {i}\n{e}')
-                fitness = 0.0
+                LOGGER.warning(f'WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}')
 
-            # Save results and mutated_hyp to evolve_csv
+            # Save results and mutated_hyp to CSV
+            fitness = metrics.get('fitness', 0.0)
             log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
-            headers = '' if self.evolve_csv.exists() else (','.join(['fitness_score'] + list(self.space.keys())) + '\n')
-            with open(self.evolve_csv, 'a') as f:
+            headers = '' if self.tune_csv.exists() else (','.join(['fitness'] + list(self.space.keys())) + '\n')
+            with open(self.tune_csv, 'a') as f:
                 f.write(headers + ','.join(map(str, log_row)) + '\n')
 
-            # Print tuning results
-            x = np.loadtxt(self.evolve_csv, ndmin=2, delimiter=',', skiprows=1)
+            # Get best results
+            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=',', skiprows=1)
             fitness = x[:, 0]  # first column
             best_idx = fitness.argmax()
             best_is_current = best_idx == i
             if best_is_current:
-                best_save_dir = results.save_dir
-                best_metrics = {k: round(v, 5) for k, v in results.results_dict.items()}
-            header = (f'{prefix} {i + 1} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
-                      f'{prefix} Results saved to {colorstr("bold", self.tune_dir)}\n'
-                      f'{prefix} Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
-                      f'{prefix} Best fitness metrics are {best_metrics}\n'
-                      f'{prefix} Best fitness model is {best_save_dir}\n'
-                      f'{prefix} Best fitness hyperparameters are printed below.\n')
-
+                best_save_dir = save_dir
+                best_metrics = {k: round(v, 5) for k, v in metrics.items()}
+                for ckpt in weights_dir.glob('*.pt'):
+                    shutil.copy2(ckpt, self.tune_dir / 'weights')
+            elif cleanup:
+                shutil.rmtree(ckpt_file.parent)  # remove iteration weights/ dir to reduce storage space
+
+            # Plot tune results
+            plot_tune_results(self.tune_csv)
+
+            # Save and print tune results
+            header = (f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
+                      f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n'
+                      f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
+                      f'{self.prefix}Best fitness metrics are {best_metrics}\n'
+                      f'{self.prefix}Best fitness model is {best_save_dir}\n'
+                      f'{self.prefix}Best fitness hyperparameters are printed below.\n')
             LOGGER.info('\n' + header)
 
             # Save turning results
-            data = {k: float(x[0, i + 1]) for i, k in enumerate(self.space.keys())}
-            header = header.replace(prefix, '#').replace('[1m/', '').replace('[0m', '') + '\n'
-            yaml_save(self.tune_dir / 'best.yaml', data=data, header=header)
-            yaml_print(self.tune_dir / 'best.yaml')
+            data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())}
+            yaml_save(self.tune_dir / 'best_hyperparameters.yaml',
+                      data=data,
+                      header=remove_colorstr(header.replace(self.prefix, '# ')) + '\n')
+            yaml_print(self.tune_dir / 'best_hyperparameters.yaml')
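Once the loop finishes, `best_hyperparameters.yaml` holds plain train arguments (the search-space keys), so it can be fed straight back into a longer training run. A minimal sketch, assuming PyYAML for loading and a hypothetical tune directory:

```python
import yaml

from ultralytics import YOLO

with open('runs/detect/tune/best_hyperparameters.yaml') as f:  # hypothetical tune_dir path
    best_hyp = yaml.safe_load(f)  # keys are the search-space names: lr0, lrf, momentum, ...

model = YOLO('yolov8n.pt')
model.train(data='coco8.yaml', epochs=100, **best_hyp)  # retrain longer with the winning values
```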