You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
74 lines
2.6 KiB
74 lines
2.6 KiB
3 years ago
|
"""Peform hyperparemeters search"""
|
||
|
|
||
|
import argparse
|
||
|
import collections
|
||
|
import itertools
|
||
|
import os
|
||
|
import sys
|
||
|
|
||
|
from common import utils
|
||
|
from experiment_dispatcher import dispatcher, tmux
|
||
|
|
||
|
PYTHON = sys.executable
|
||
|
parser = argparse.ArgumentParser()
|
||
|
parser.add_argument('--parent_dir', default='experiments', help='Directory containing params.json')
|
||
|
parser.add_argument('--id', default=1, type=int, help="Experiment id")
|
||
|
|
||
|
|
||
|
def launch_training_job(exp_dir, exp_name, session_name, param_pool_dict, params, start_id=0):
|
||
|
# Partition tmux windows automatically
|
||
|
tmux_ops = tmux.TmuxOps()
|
||
|
# Combining hyper-parameters and experiment ID automatically
|
||
|
task_manager = dispatcher.Enumerate_params_dict(task_thread=0, if_single_id_task=True, **param_pool_dict)
|
||
|
|
||
|
num_jobs = len([v for v in itertools.product(*param_pool_dict.values())])
|
||
|
exp_cmds = []
|
||
|
|
||
|
for job_id in range(num_jobs):
|
||
|
param_pool = task_manager.get_thread(ind=job_id)
|
||
|
for hyper_params in param_pool:
|
||
|
job_name = 'exp_{}'.format(job_id + start_id)
|
||
|
for k in hyper_params.keys():
|
||
|
params.dict[k] = hyper_params[k]
|
||
|
|
||
|
params.dict['model_dir'] = os.path.join(exp_dir, exp_name, job_name)
|
||
|
model_dir = params.dict['model_dir']
|
||
|
|
||
|
if not os.path.exists(model_dir):
|
||
|
os.makedirs(model_dir)
|
||
|
|
||
|
# Write parameters in json file
|
||
|
json_path = os.path.join(model_dir, 'params.json')
|
||
|
params.save(json_path)
|
||
|
|
||
|
# Launch training with this config
|
||
|
cmd = 'rlaunch --cpu={} --memory={} --gpu={} -- python train.py --model_dir {}'.format(params.cpu, params.memory, params.gpu, model_dir)
|
||
|
exp_cmds.append(cmd)
|
||
|
|
||
|
tmux_ops.run_task(exp_cmds, task_name=exp_name, session_name=session_name)
|
||
|
|
||
|
|
||
|
def experiment():
|
||
|
# Load the "reference" parameters from parent_dir json file
|
||
|
args = parser.parse_args()
|
||
|
json_path = os.path.join(args.parent_dir, 'params.json')
|
||
|
assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path)
|
||
|
params = utils.Params(json_path)
|
||
|
|
||
|
if args.id == 1:
|
||
|
# e.g. model and logs will be stored under 'experiment_learning_rate'
|
||
|
name = "learning_rate"
|
||
|
session_name = 'exp' # tmux session name, need pre-create
|
||
|
start_id = 0
|
||
|
exp_name = 'experiment_{}'.format(name)
|
||
|
param_pool_dict = collections.OrderedDict()
|
||
|
param_pool_dict['learning_rate'] = [0.0005, 0.001]
|
||
|
else:
|
||
|
raise NotImplementedError
|
||
|
|
||
|
launch_training_job(args.parent_dir, exp_name, session_name, param_pool_dict, params, start_id)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
experiment()
|