|
|
|
@ -2,8 +2,8 @@ |
|
|
|
|
|
|
|
|
|
import subprocess |
|
|
|
|
|
|
|
|
|
from ultralytics.cfg import TASK2DATA, TASK2METRIC |
|
|
|
|
from ultralytics.utils import DEFAULT_CFG_DICT, LOGGER, NUM_THREADS |
|
|
|
|
from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_save_dir |
|
|
|
|
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_ray_tune(model, |
|
|
|
@ -93,9 +93,10 @@ def run_ray_tune(model, |
|
|
|
|
Returns: |
|
|
|
|
None. |
|
|
|
|
""" |
|
|
|
|
model._reset_callbacks() |
|
|
|
|
model.reset_callbacks() |
|
|
|
|
config.update(train_args) |
|
|
|
|
model.train(**config) |
|
|
|
|
results = model.train(**config) |
|
|
|
|
return results.results_dict |
|
|
|
|
|
|
|
|
|
# Get search space |
|
|
|
|
if not space: |
|
|
|
@ -123,10 +124,12 @@ def run_ray_tune(model, |
|
|
|
|
tuner_callbacks = [WandbLoggerCallback(project='YOLOv8-tune')] if wandb else [] |
|
|
|
|
|
|
|
|
|
# Create the Ray Tune hyperparameter search tuner |
|
|
|
|
tune_dir = get_save_dir(DEFAULT_CFG, name='tune') |
|
|
|
|
tune_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
tuner = tune.Tuner(trainable_with_resources, |
|
|
|
|
param_space=space, |
|
|
|
|
tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples), |
|
|
|
|
run_config=RunConfig(callbacks=tuner_callbacks, storage_path='./runs/tune')) |
|
|
|
|
run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir)) |
|
|
|
|
|
|
|
|
|
# Run the hyperparameter search |
|
|
|
|
tuner.fit() |
|
|
|
|