|
|
|
@ -3,7 +3,7 @@ |
|
|
|
|
import subprocess |
|
|
|
|
|
|
|
|
|
from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_save_dir |
|
|
|
|
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS |
|
|
|
|
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_ray_tune( |
|
|
|
@ -40,7 +40,7 @@ def run_ray_tune( |
|
|
|
|
train_args = {} |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
subprocess.run("pip install ray[tune]".split(), check=True) |
|
|
|
|
subprocess.run("pip install ray[tune]<=2.9.3".split(), check=True) # do not add single quotes here |
|
|
|
|
|
|
|
|
|
import ray |
|
|
|
|
from ray import tune |
|
|
|
@ -48,7 +48,7 @@ def run_ray_tune( |
|
|
|
|
from ray.air.integrations.wandb import WandbLoggerCallback |
|
|
|
|
from ray.tune.schedulers import ASHAScheduler |
|
|
|
|
except ImportError: |
|
|
|
|
raise ModuleNotFoundError('Tuning hyperparameters requires Ray Tune. Install with: pip install "ray[tune]"') |
|
|
|
|
raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install "ray[tune]<=2.9.3"') |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
import wandb |
|
|
|
@ -57,6 +57,7 @@ def run_ray_tune( |
|
|
|
|
except (ImportError, AssertionError): |
|
|
|
|
wandb = False |
|
|
|
|
|
|
|
|
|
checks.check_version(ray.__version__, "<=2.9.3", "ray") |
|
|
|
|
default_space = { |
|
|
|
|
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']), |
|
|
|
|
"lr0": tune.uniform(1e-5, 1e-1), |
|
|
|
|