You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
65 lines
2.4 KiB
65 lines
2.4 KiB
class Enumerate_params_dict():
    """Enumerate every combination of parameter values and split them into threads.

    Each keyword argument passed to the constructor names one parameter and maps
    it to a non-empty list of candidate values, e.g.
    ``Enumerate_params_dict(task_thread=2, lr=[0.1, 0.01], wd=[1e-4, 1e-5])``.
    The full Cartesian product of all pools is built (keys are visited in sorted
    order; the first sorted key varies slowest). The resulting ``{key: value}``
    dicts are then distributed round-robin over ``task_thread`` buckets
    ("threads").

    Attributes:
        item_pool_dict: the keyword-argument pools exactly as passed in.
        key_ind_pairs: one tuple per combination, holding ``(key, index)`` pairs
            that index into ``item_pool_dict``.
        thread_pool: list of ``task_thread`` lists of ``{key: value}`` dicts;
            a thread may be empty when there are more threads than combinations.
    """

    def __init__(self, task_thread: int, if_single_id_task=False, **kwarg):
        """Build all parameter combinations and split them into threads.

        Args:
            task_thread: number of buckets to spread the experiments over; must
                be at least 1. Ignored when ``if_single_id_task`` is true.
            if_single_id_task: if true, use one bucket per combination, so every
                bucket holds exactly one parameter dict.
            **kwarg: parameter name -> non-empty list of candidate values.

        Raises:
            ValueError: if a pool is not a list, if a pool is empty, or if the
                effective number of threads is smaller than 1.
        """
        from itertools import product

        item_key_ls = sorted(kwarg)

        # Validate pools in the same (sorted) order they are enumerated in, so
        # the first offending key is the one reported.
        for item_key in item_key_ls:
            item_pool = kwarg[item_key]
            if not isinstance(item_pool, list):
                raise ValueError('wrong type of the item pool:%s' % type(item_pool))
            if not item_pool:
                raise ValueError('item_pool should not be empty, length of the item_pool: %s' % len(item_pool))

        # product() with no arguments yields a single empty tuple, so building
        # the manager without any pools still yields one (empty) experiment.
        index_ranges = [range(len(kwarg[item_key])) for item_key in item_key_ls]
        data_ls = [tuple(zip(item_key_ls, combo)) for combo in product(*index_ranges)]

        self.item_pool_dict = kwarg
        self.key_ind_pairs = data_ls

        if if_single_id_task:
            task_thread = len(data_ls)
        if task_thread < 1:
            # Without this guard the round-robin split would end in a
            # ZeroDivisionError (there is always at least one combination).
            raise ValueError('task_thread should be at least 1, got: %s' % task_thread)

        param_dicts = [
            {p_key: kwarg[p_key][ind] for p_key, ind in pair}
            for pair in data_ls
        ]
        # Round-robin: thread i receives experiments i, i + N, i + 2N, ...
        self.thread_pool = [param_dicts[i::task_thread] for i in range(task_thread)]

    def get_thread(self, ind: int):
        """Return the list of parameter dicts assigned to thread ``ind``.

        ``ind`` wraps around modulo the number of threads, so any integer
        (including one larger than the thread count) selects a valid thread.
        """
        return self.thread_pool[ind % len(self.thread_pool)]

    @classmethod
    def demo(cls):
        """Print thread 0 and the whole thread pool for a small 3x3 grid."""
        test_id = 0
        param_pool_dict = {
            'a': [0, 1, 2],
            'b': [0, 1, 2],
        }
        # Use ``cls`` rather than the hard-coded class name so subclasses demo themselves.
        task_manager = cls(task_thread=2, if_single_id_task=True, **param_pool_dict)
        param_pool = task_manager.get_thread(ind=test_id)
        print(param_pool)
        print('thread_pool', len(task_manager.thread_pool), task_manager.thread_pool)
|
|
|