You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
66 lines
2.4 KiB
66 lines
2.4 KiB
3 years ago
|
class Enumerate_params_dict():
    """Enumerate every combination of named parameter pools and split them into threads.

    Each keyword argument names a parameter and supplies the list of values to
    try, e.g. ``lr=[0.1, 0.01, 0.001], wd=[1e-4, 1e-5, 1e-6]``.  The Cartesian
    product of all pools is built (keys taken in sorted order, the last key
    varying fastest) and the combinations are dealt round-robin into buckets
    ("threads") so that separate workers can each take one bucket.

    Attributes:
        item_pool_dict: the keyword pools exactly as they were passed in.
        key_ind_pairs: list of combinations; each combination is a tuple of
            ``(key, index_into_pool)`` pairs.
        thread_pool: list of buckets; each bucket is a list of
            ``{key: value}`` dicts, one dict per combination.
    """

    def __init__(self, task_thread: int, if_single_id_task=False, **kwarg):
        """Build all parameter combinations and distribute them over buckets.

        Args:
            task_thread: number of buckets to split the combinations into.
                Ignored when ``if_single_id_task`` is true.
            if_single_id_task: when true, use one bucket per combination, so
                every bucket holds exactly one parameter dict.
            **kwarg: parameter name -> non-empty list of candidate values.

        Raises:
            ValueError: if a pool is not a list, if a pool is empty, or if the
                effective number of buckets is smaller than 1.
        """
        # Function-scope import keeps this class self-contained.
        from itertools import product

        # Sorted keys make the enumeration order independent of call order.
        keys = sorted(kwarg)
        for key in keys:
            pool = kwarg[key]
            if not isinstance(pool, list):
                raise ValueError('wrong type of the item pool:%s' % type(pool))
            if not pool:
                raise ValueError('item_pool should not be empty, length of the item_pool: %s' % len(pool))

        # With no pools at all, product() yields one empty combination, which
        # gives a single empty parameter dict.
        index_ranges = [range(len(kwarg[key])) for key in keys]
        self.item_pool_dict = kwarg
        self.key_ind_pairs = [tuple(zip(keys, combo)) for combo in product(*index_ranges)]

        n_threads = len(self.key_ind_pairs) if if_single_id_task else task_thread
        if n_threads < 1:
            # Zero buckets would otherwise surface as a modulo-by-zero below.
            raise ValueError('task_thread should be a positive integer, got: %s' % (n_threads,))

        # Round-robin: combination k goes to bucket k % n_threads.
        buckets = [[] for _ in range(n_threads)]
        for position, pair in enumerate(self.key_ind_pairs):
            buckets[position % n_threads].append({key: kwarg[key][ind] for key, ind in pair})
        self.thread_pool = buckets

    def get_thread(self, ind: int) -> list:
        """Return the bucket of parameter dicts for worker ``ind``.

        The index wraps around modulo the number of buckets, so any integer
        maps onto an existing bucket.
        """
        return self.thread_pool[ind % len(self.thread_pool)]

    @classmethod
    def demo(cls):
        """Print the bucket for task id 0 and the full bucket list for a 3x3 grid."""
        test_id = 0
        param_pool_dict = {
            'a': [0, 1, 2],
            'b': [0, 1, 2],
        }
        task_manager = cls(task_thread=2, if_single_id_task=True, **param_pool_dict)
        param_pool = task_manager.get_thread(ind=test_id)
        print(param_pool)
        print('thread_pool', len(task_manager.thread_pool), task_manager.thread_pool)
|