`ultralytics 8.0.168` save `Tuner` results every iteration (#4680)

pull/4682/head v8.0.168
Glenn Jocher 1 year ago committed by GitHub
parent 4b6147dd6f
commit 263bfd1e93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      docs/yolov5/tutorials/hyperparameter_evolution.md
  2. 2
      ultralytics/__init__.py
  3. 39
      ultralytics/engine/tuner.py
  4. 5
      ultralytics/utils/__init__.py
  5. 2
      ultralytics/utils/autobatch.py

@ -100,7 +100,6 @@ done
```
The default evolution settings will run the base scenario 300 times, i.e. for 300 generations. You can modify generations via the `--evolve` argument, i.e. `python train.py --evolve 1000`.
https://github.com/ultralytics/yolov5/blob/6a3ee7cf03efb17fbffde0e68b1a854e80fe3213/train.py#L608
The main genetic operators are **crossover** and **mutation**. In this work mutation is used, with an 80% probability and a 0.04 variance to create new offspring based on a combination of the best parents from all previous generations. Results are logged to `runs/evolve/exp/evolve.csv`, and the highest fitness offspring is saved every generation as `runs/evolve/hyp_evolved.yaml`:

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.167'
__version__ = '8.0.168'
from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM

@ -159,16 +159,18 @@ class Tuner:
"""
t0 = time.time()
best_save_dir, best_metrics = None, None
self.tune_dir.mkdir(parents=True, exist_ok=True)
for i in range(iterations):
# Mutate hyperparameters
mutated_hyp = self._mutate()
LOGGER.info(f'{prefix} Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}')
# Initialize and train YOLOv8 model
try:
# Train YOLO model with mutated hyperparameters
train_args = {**vars(self.args), **mutated_hyp}
fitness = (deepcopy(model) or YOLO(self.args.model)).train(**train_args).fitness # results.fitness
results = (deepcopy(model) or YOLO(self.args.model)).train(**train_args)
fitness = results.fitness
except Exception as e:
LOGGER.warning(f'WARNING ❌ training failure for hyperparameter tuning iteration {i}\n{e}')
fitness = 0.0
@ -179,14 +181,25 @@ class Tuner:
with open(self.evolve_csv, 'a') as f:
f.write(headers + ','.join(map(str, log_row)) + '\n')
# Print tuning results
x = np.loadtxt(self.evolve_csv, ndmin=2, delimiter=',', skiprows=1)
fitness = x[:, 0] # first column
i = np.argsort(-fitness)[0] # best fitness index
LOGGER.info(f'\n{prefix} All iterations complete ✅ ({time.time() - t0:.2f}s)\n'
f'{prefix} Results saved to {colorstr("bold", self.tune_dir)}\n'
f'{prefix} Best fitness={fitness[i]} observed at iteration {i}')
# Save turning results
yaml_save(self.tune_dir / 'best.yaml', data={k: float(x[0, i + 1]) for i, k in enumerate(self.space.keys())})
yaml_print(self.tune_dir / 'best.yaml')
# Print tuning results
x = np.loadtxt(self.evolve_csv, ndmin=2, delimiter=',', skiprows=1)
fitness = x[:, 0] # first column
best_idx = fitness.argmax()
best_is_current = best_idx == i
if best_is_current:
best_save_dir = results.save_dir
best_metrics = {k: round(v, 5) for k, v in results.results_dict.items()}
header = (f'{prefix} {i + 1} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
f'{prefix} Results saved to {colorstr("bold", self.tune_dir)}\n'
f'{prefix} Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
f'{prefix} Best fitness metrics are {best_metrics}\n'
f'{prefix} Best fitness model is {best_save_dir}\n'
f'{prefix} Best fitness hyperparameters are printed below.\n')
LOGGER.info('\n' + header)
# Save turning results
data = {k: float(x[0, i + 1]) for i, k in enumerate(self.space.keys())}
header = header.replace(prefix, '#').replace('/', '').replace('', '') + '\n'
yaml_save(self.tune_dir / 'best.yaml', data=data, header=header)
yaml_print(self.tune_dir / 'best.yaml')

@ -286,13 +286,14 @@ class ThreadingLocked:
return decorated
def yaml_save(file='data.yaml', data=None):
def yaml_save(file='data.yaml', data=None, header=''):
"""
Save YAML data to a file.
Args:
file (str, optional): File name. Default is 'data.yaml'.
data (dict): Data to save in YAML format.
header (str, optional): YAML header to add.
Returns:
(None): Data is saved to the specified file.
@ -311,6 +312,8 @@ def yaml_save(file='data.yaml', data=None):
# Dump data to file in YAML format
with open(file, 'w') as f:
if header:
f.write(header)
yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True)

@ -29,7 +29,7 @@ def check_train_batch_size(model, imgsz=640, amp=True):
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
def autobatch(model, imgsz=640, fraction=0.67, batch_size=DEFAULT_CFG.batch):
def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
"""
Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.

Loading…
Cancel
Save