# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import re
import sys

from tap import Tap

import dist

line_sep = f'\n{"=" * 80}\n'


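# `Args` is built on Tap (typed-argument-parser): each annotated class attribute
# below becomes a same-named command-line flag (e.g. `mask: float = 0.6` gives
# `--mask`), and `explicit_bool=True` in `init_dist_and_get_args` makes boolean
# flags take explicit True/False values.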
class Args(Tap):
    # environment
    local_rank: int  # unused; passed in by the distributed launcher
    exp_name: str
    data_path: str
    exp_dir: str
    log_txt_name: str = '/some/path/like/this/log.txt'  # [overwritten in `init_dist_and_get_args`]
    resume: str = ''  # checkpoint to resume from ('' = none)
    seed: int = 1
    device: str = 'cpu'  # [overwritten in `init_dist_and_get_args`]
    
    # key MIM hp
    mask: float = 0.6  # mask ratio
    mask2: float = -1  # second mask ratio; negative means "same as mask"
    uni: bool = False
    pe: bool = False
    pn: int = 1
    py: int = 4  # clamped to >= 1 in `init_dist_and_get_args`
    # other MIM hp
    den: bool = False
    loss_l2: bool = True  # use an L2 reconstruction loss
    en_de_norm: str = 'bn'  # encoder-decoder normalization: 'bn' or 'ln'
    en_de_lin: bool = True
    
    # encoder
    model: str = 'res50'  # model name or alias; resolved in `init_dist_and_get_args`
    model_alias: str = 'res50'
    input_size: int = 224
    sbn: bool = True  # use SyncBatchNorm
    # decoder
    dec_dim: int = 512  # [could be changed in `main.py`]
    double: bool = True
    hea: str = '0_1'  # underscore-separated ints; parsed into a list in `init_dist_and_get_args`
    cmid: int = 0
    
    # pre-training hyperparameters
    glb_batch_size: int = 0  # [computed in `init_dist_and_get_args`]
    batch_size: int = 0  # batch size per GPU [computed in `init_dist_and_get_args`]
    dp: float = 0.0  # drop path rate
    base_lr: float = 2e-4  # lr per 256 samples; actual lr = base_lr * glb_batch_size / 256
    lr: float = None  # [computed in `init_dist_and_get_args`]
    wd: float = 0.04  # weight decay (initial)
    wde: float = 0.2  # weight decay (final); 0 means "same as wd"
    ep: int = 1600  # total pre-training epochs
    wp_ep: int = 40  # warm-up epochs
    clip: float = 5.  # gradient clipping
    opt: str = ''  # optimizer name; '' means a per-model default is chosen
    ada: float = 0.  # optimizer beta2; 0 means a per-model default is chosen
    
    # data hyperparameters
    data_set: str = 'imn'  # 'imn' = ImageNet
    rrc: float = 0.67  # RandomResizedCrop's lower area-ratio bound
    bs: int = 4096  # global batch size (split across GPUs in `init_dist_and_get_args`)
    num_workers: int = 8  # dataloader workers per GPU
    
    # would be added during runtime
    cmd: str = ''  # full command line of this run
    commit_id: str = ''  # git commit hash
    commit_msg: str = ''  # git commit message
    last_loss = 1e9  # [would be changed in `main.py`]
    cur_phase: str = ''  # [would be changed in `main.py`]
    cur_ep: str = ''  # [would be changed in `main.py`]
    remain_time: str = ''  # [would be changed in `main.py`]
    finish_time: str = ''  # [would be changed in `main.py`]
    
    first_logging: bool = True  # lets `log_epoch` write the header record only once
    
    @property
    def is_convnext(self):
        return 'convnext' in self.model or 'cnx' in self.model
    
    @property
    def is_resnet(self):
        return 'res' in self.model or 'res' in self.model_alias
    
    def __str__(self):
        return re.sub(r"(\[LE-FT\]:\s*)('\s+')?", r'\1', super(Args, self).__str__())
    
    def log_epoch(self):
        if not dist.is_local_master():
            return
        
        if self.first_logging:
            self.first_logging = False
            with open(self.log_txt_name, 'w') as fp:
                json.dump({
                    'name': self.exp_name, 'cmd': self.cmd, 'commit_id': self.commit_id,
                    'model': self.model, 'opt': self.opt,
                }, fp)
                print('', end='\n', file=fp)
        
        with open(self.log_txt_name, 'a') as fp:
            json.dump({
                'cur': self.cur_phase, 'cur_ep': self.cur_ep,
                'last_L': self.last_loss,
                'rema': self.remain_time, 'fini': self.finish_time,
            }, fp)
            print('', end='\n', file=fp)  # keep one JSON record per line
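        # log.txt thus holds one JSON record per line; an illustrative example
        # (values are made up, not from a real run):
        #   {"name": "exp0", "cmd": "--exp_name=exp0 ...", "commit_id": "abc1234", "model": "resnet50", "opt": "lamb"}
        #   {"cur": "pretrain", "cur_ep": "100/1600", "last_L": 0.45, "rema": "1 day, 2:34:56", "fini": "05-20 18:00"}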


def init_dist_and_get_args():
    from utils import misc
    from models import model_alias_to_fullname, model_fullname_to_alias
    
    # initialize
    args = Args(explicit_bool=True).parse_args()
    misc.init_distributed_environ(exp_dir=args.exp_dir)
    
    # update args
    args.cmd = ' '.join(sys.argv[1:])
    args.commit_id = os.popen('git rev-parse HEAD').read().strip()
    args.commit_msg = os.popen('git log -1').read().strip().splitlines()[-1].strip()
    
    if args.model in model_alias_to_fullname.keys():
        args.model = model_alias_to_fullname[args.model]
    args.model_alias = model_fullname_to_alias[args.model]
    
    args.device = dist.get_device()
    args.batch_size = args.bs // dist.get_world_size()
    args.glb_batch_size = args.batch_size * dist.get_world_size()
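    # e.g. --bs=4096 on 8 GPUs gives batch_size=512 per GPU and glb_batch_size=4096;
    # the floor division means glb_batch_size is always divisible by the world size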
    
    if args.is_resnet:
        args.opt = args.opt or 'lamb'
        args.ada = args.ada or 0.95
    
    if args.is_convnext:
        args.opt = args.opt or 'lamb'
        args.ada = args.ada or 0.999
        args.en_de_norm = 'ln'  # ConvNeXt uses LayerNorm
    
    args.opt = args.opt.lower()
    args.lr = args.base_lr * args.glb_batch_size / 256  # linear lr scaling rule
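    # e.g. the defaults base_lr=2e-4, bs=4096 give lr = 2e-4 * 4096 / 256 = 3.2e-3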
    args.wde = args.wde or args.wd  # fall back to wd if wde is 0
    
    if args.mask2 < 0:  # negative mask2 means "same as mask"
        args.mask2 = args.mask
    args.mask, args.mask2 = min(args.mask, args.mask2), max(args.mask, args.mask2)
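    # after this, mask <= mask2 always holds; e.g. --mask=0.7 --mask2=0.5
    # ends up as mask=0.5, mask2=0.7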
    
    if args.py <= 0:
        args.py = 1
    
    args.hea = list(map(int, args.hea.split('_')))  # e.g. '0_1' -> [0, 1]
    
    args.log_txt_name = os.path.join(args.exp_dir, 'log.txt')
    
    return args
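
# Typical usage from a training entry point (a hypothetical `main.py`; every name
# in this sketch other than `init_dist_and_get_args`, `Args.log_epoch`, `args.ep`,
# and `line_sep` is an illustrative assumption, not part of this file):
#
#   from arg_util import init_dist_and_get_args, line_sep
#
#   def main():
#       args = init_dist_and_get_args()  # parse CLI flags, init distributed env
#       print(args, end=line_sep)        # uses the custom __str__ above
#       for ep in range(args.ep):
#           ...                          # one pre-training epoch
#           args.log_epoch()             # append one JSON record to log.txt
#
#   if __name__ == '__main__':
#       main()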