You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

73 lines
2.6 KiB

"""Peform hyperparemeters search"""
import argparse
import collections
import itertools
import os
import sys
from common import utils
from experiment_dispatcher import dispatcher, tmux
PYTHON = sys.executable
parser = argparse.ArgumentParser()
parser.add_argument('--parent_dir', default='experiments', help='Directory containing params.json')
parser.add_argument('--id', default=1, type=int, help="Experiment id")
def launch_training_job(exp_dir, exp_name, session_name, param_pool_dict, params, start_id=0):
    """Launch one training job per hyper-parameter combination, each in a tmux window.

    Args:
        exp_dir: Root directory under which experiment outputs are stored.
        exp_name: Name of this experiment group; becomes a sub-directory of exp_dir.
        session_name: Name of the (pre-created) tmux session to run jobs in.
        param_pool_dict: OrderedDict mapping hyper-parameter name -> list of values;
            the Cartesian product of the value lists defines the job grid.
        params: Reference configuration object (utils.Params); its `dict` attribute
            is updated in place with each job's hyper-parameters before saving.
        start_id: Offset added to the job index when naming job directories.
    """
    # Partition tmux windows automatically.
    tmux_ops = tmux.TmuxOps()

    # Combining hyper-parameters and experiment ID automatically.
    task_manager = dispatcher.Enumerate_params_dict(
        task_thread=0, if_single_id_task=True, **param_pool_dict)

    # Size of the Cartesian product of all value lists. Computed arithmetically
    # instead of materializing itertools.product just to take its length.
    num_jobs = 1
    for values in param_pool_dict.values():
        num_jobs *= len(values)

    exp_cmds = []
    for job_id in range(num_jobs):
        param_pool = task_manager.get_thread(ind=job_id)
        for hyper_params in param_pool:
            job_name = 'exp_{}'.format(job_id + start_id)
            # Overlay this job's hyper-parameters onto the reference config.
            params.dict.update(hyper_params)
            params.dict['model_dir'] = os.path.join(exp_dir, exp_name, job_name)
            model_dir = params.dict['model_dir']
            # exist_ok avoids the check-then-create race of exists() + makedirs().
            os.makedirs(model_dir, exist_ok=True)

            # Write parameters in json file next to the job's outputs.
            json_path = os.path.join(model_dir, 'params.json')
            params.save(json_path)

            # Launch training with this config.
            cmd = 'rlaunch --cpu={} --memory={} --gpu={} -- python train.py --model_dir {}'.format(
                params.cpu, params.memory, params.gpu, model_dir)
            exp_cmds.append(cmd)

    tmux_ops.run_task(exp_cmds, task_name=exp_name, session_name=session_name)
def experiment():
    """Parse CLI args, load the reference params, and launch the selected grid.

    Raises:
        FileNotFoundError: If params.json is missing from --parent_dir.
        NotImplementedError: If --id does not name a known experiment.
    """
    # Load the "reference" parameters from parent_dir json file.
    args = parser.parse_args()
    json_path = os.path.join(args.parent_dir, 'params.json')
    # Raise explicitly rather than via `assert`, which is stripped under `python -O`.
    if not os.path.isfile(json_path):
        raise FileNotFoundError("No json configuration file found at {}".format(json_path))
    params = utils.Params(json_path)

    if args.id == 1:
        # e.g. model and logs will be stored under 'experiment_learning_rate'
        name = "learning_rate"
        session_name = 'exp'  # tmux session name; must be created beforehand
        start_id = 0
        exp_name = 'experiment_{}'.format(name)
        param_pool_dict = collections.OrderedDict()
        param_pool_dict['learning_rate'] = [0.0005, 0.001]
    else:
        raise NotImplementedError('Unknown experiment id: {}'.format(args.id))

    launch_training_job(args.parent_dir, exp_name, session_name, param_pool_dict, params, start_id)
if __name__ == "__main__":
experiment()