You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

174 lines
7.8 KiB

2 years ago
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import datetime
2 years ago
import math
import sys
import time
from functools import partial
from typing import List
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
import dist
import encoder
from decoder import LightDecoder
from models import build_sparse_encoder
from sampler import DistInfiniteBatchSampler, worker_init_fn
from spark import SparK
from utils import arg_util, misc, lamb
from utils.imagenet import build_imagenet_pretrain
2 years ago
from utils.lr_control import lr_wd_annealing, get_param_groups
def main_pt():
    """Entry point for SparK self-supervised pre-training.

    Initializes distributed state and parses the run arguments, builds the
    pre-training data loader, the sparse encoder / light decoder pair wrapped in
    ``SparK`` and ``DistributedDataParallel``, and the optimizer selected by
    ``args.opt``.

    It then tries to resume from ``args.resume_from`` and trains from the
    returned epoch up to ``args.ep``. If the loaded checkpoint is already at or
    past ``args.ep``, no training is done.

    After every epoch, two files are written:
      * ``{args.model}_still_pretraining.pth``: SparK ``state_dict(with_config=True)``
        plus the optimizer state;
      * ``{args.model}_1kpretrained.pth``: only the encoder's ``sp_cnn`` weights,
        saved through ``misc.save_checkpoint_for_finetune``.
    """
    args: arg_util.Args = arg_util.init_dist_and_get_args()
    print(f'initial args:\n{str(args)}')
    args.log_epoch()

    # build data
    print(f'[build data for pre-training] ...\n')
    dataset_train = build_imagenet_pretrain(args.data_path, args.input_size)
    data_loader_train = DataLoader(
        dataset=dataset_train, num_workers=args.dataloader_workers, pin_memory=True,
        batch_sampler=DistInfiniteBatchSampler(
            dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, seed=args.seed,
            shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(),
        ), worker_init_fn=worker_init_fn
    )
    # One long-lived iterator is shared by all epochs (the batch sampler is,
    # judging by its name, infinite). iters_train is the per-epoch iteration
    # count that pre_train_one_ep consumes.
    itrt_train, iters_train = iter(data_loader_train), len(data_loader_train)
    print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size_per_gpu}, iters_train={iters_train}')

    # build encoder and decoder
    enc: encoder.SparseEncoder = build_sparse_encoder(args.model, input_size=args.input_size, sbn=args.sbn, drop_path_rate=args.dp, verbose=False)
    dec = LightDecoder(enc.downsample_raito, sbn=args.sbn)
    model_without_ddp = SparK(
        sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask,
        densify_norm=args.densify_norm, sbn=args.sbn, hierarchy=args.hierarchy,
    ).to(args.device)
    print(f'[PT model] model = {model_without_ddp}\n')
    model: DistributedDataParallel = DistributedDataParallel(model_without_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)

    # build optimizer and lr_scheduler
    # The optimizer is created with weight_decay=0.0. lr_wd_annealing receives
    # args.wd / args.wde every iteration, so it presumably sets the real values
    # (and presumably skips the nowd_keys groups). TODO confirm in utils.lr_control.
    param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'cls_token', 'pos_embed', 'mask_token', 'gamma'})
    opt_clz = {
        'sgd': partial(torch.optim.SGD, momentum=0.9, nesterov=True),
        'adamw': partial(torch.optim.AdamW, betas=(0.9, args.ada)),
        'lamb': partial(lamb.TheSameAsTimmLAMB, betas=(0.9, args.ada), max_grad_norm=5.0),
    }[args.opt]
    optimizer = opt_clz(params=param_groups, lr=args.lr, weight_decay=0.0)
    print(f'[optimizer] optimizer({opt_clz}) ={optimizer}\n')

    # try to resume
    ep_start, performance_desc = misc.load_checkpoint(args.resume_from, model_without_ddp, optimizer)
    if ep_start >= args.ep:  # load from a complete checkpoint file
        print(f' [*] [PT already done] Min/Last Recon Loss: {performance_desc}')
    else:  # perform pre-training
        tb_lg = misc.TensorboardLogger(args.tb_lg_dir, is_master=dist.is_master(), prefix='pt')
        # +inf guarantees that the first epoch's loss becomes the minimum. A finite
        # sentinel such as 1e9 would be reported as the "min" if every loss exceeded it.
        min_loss = math.inf
        print(f'[PT start] from ep{ep_start}')

        pt_start_time = time.time()
        for ep in range(ep_start, args.ep):
            ep_start_time = time.time()
            tb_lg.set_step(ep * iters_train)
            # NOTE(review): iterators returned by iter(DataLoader) normally have no
            # set_epoch, so this looks like a no-op kept for custom iterators. Confirm.
            if hasattr(itrt_train, 'set_epoch'):
                itrt_train.set_epoch(ep)
            stats = pre_train_one_ep(ep, args, tb_lg, itrt_train, iters_train, model, optimizer)

            last_loss = stats['last_loss']
            min_loss = min(min_loss, last_loss)
            performance_desc = f'{min_loss:.4f} {last_loss:.4f}'
            misc.save_checkpoint(f'{args.model}_still_pretraining.pth', args, ep, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())
            misc.save_checkpoint_for_finetune(f'{args.model}_1kpretrained.pth', args, model_without_ddp.sparse_encoder.sp_cnn.state_dict())

            # ETA estimate: assume each remaining epoch costs as much as this one.
            ep_cost = round(time.time() - ep_start_time, 2) + 1  # +1s: approximate the following logging cost
            remain_secs = (args.ep - 1 - ep) * ep_cost
            remain_time = datetime.timedelta(seconds=round(remain_secs))
            finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs))
            print(f' [*] [ep{ep}/{args.ep}] Min/Last Recon Loss: {performance_desc}, Cost: {ep_cost}s, Remain: {remain_time}, Finish @ {finish_time}')

            args.cur_ep = f'{ep + 1}/{args.ep}'
            args.remain_time, args.finish_time = str(remain_time), str(finish_time)
            args.last_loss = last_loss
            args.log_epoch()
            tb_lg.update(min_loss=min_loss, head='train', step=ep)
            tb_lg.update(rest_hours=round(remain_secs/60/60, 2), head='z_burnout', step=ep)
            tb_lg.flush()

        # finish pre-training
        tb_lg.update(min_loss=min_loss, head='result', step=ep_start)
        tb_lg.update(min_loss=min_loss, head='result', step=args.ep)
        tb_lg.flush()
        print(f'final args:\n{str(args)}')
        print('\n\n')
        print(f' [*] [PT finished] Min/Last Recon Loss: {performance_desc}, Total Cost: {(time.time() - pt_start_time) / 60 / 60:.1f}h\n')
        print('\n\n')
        tb_lg.close()
        time.sleep(10)

    # Runs on both paths: record that no time remains and stamp the current time.
    args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time()))
    args.log_epoch()
def pre_train_one_ep(ep, args: arg_util.Args, tb_lg: misc.TensorboardLogger, itrt_train, iters_train, model: DistributedDataParallel, optimizer):
    """Run one epoch (``iters_train`` iterations) of SparK pre-training.

    Args:
        ep: zero-based epoch index. Together with ``iters_train`` it gives the
            global iteration passed to the lr/wd schedule.
        args: run arguments; this function reads ``lr``, ``wd``, ``wde``,
            ``wp_ep``, ``ep``, ``clip`` and ``device``.
        tb_lg: TensorBoard logger; ``set_step()`` is called once per iteration.
        itrt_train: iterator yielding ``(inp, _)`` pairs; the second item is ignored.
        iters_train: number of iterations that make up one epoch.
        model: the (DDP-wrapped) SparK model; the third element of its output
            is used as the loss.
        optimizer: optimizer passed to ``lr_wd_annealing`` each iteration,
            presumably so that it sets the scheduled lr/wd on the param groups.

    Returns:
        dict mapping every metric recorded by the MetricLogger (``last_loss``,
        ``max_lr`` and, when a gradient norm is available, ``orig_norm``) to
        its ``global_avg`` after ``synchronize_between_processes()``.

    Calls ``sys.exit(-1)`` if a non-finite loss is encountered.
    """
    model.train()
    me = misc.MetricLogger(delimiter=' ')
    me.add_meter('max_lr', misc.SmoothedValue(window_size=1, fmt='{value:.5f}'))
    header = f'[PT] Epoch {ep}:'

    optimizer.zero_grad()
    # Optimizers that expose `global_grad_norm` (presumably the LAMB variant,
    # which main_pt builds with max_grad_norm) report the norm themselves, so
    # manual clipping is skipped for them.
    late_clipping = hasattr(optimizer, 'global_grad_norm')
    early_clipping = args.clip > 0 and not late_clipping
    params_req_grad = [p for p in model.parameters() if p.requires_grad] if early_clipping else []

    # Schedule bounds are fixed for the epoch; only the current iteration changes.
    warmup_iters = args.wp_ep * iters_train
    total_iters = args.ep * iters_train
    # torch.cuda.synchronize() raises when CUDA is unavailable, so only call it when it can work.
    synchronize_cuda = torch.cuda.is_available()

    for it, (inp, _) in enumerate(me.log_every(iters_train, itrt_train, 3, header)):
        # adjust lr and wd
        cur_it = it + ep * iters_train
        min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer, args.lr, args.wd, args.wde, cur_it, warmup_iters, total_iters)

        # forward and backward
        inp = inp.to(args.device, non_blocking=True)
        _, _, loss = model(inp)
        optimizer.zero_grad()
        loss.backward()
        loss = loss.item()
        if not math.isfinite(loss):
            print(f'[rk{dist.get_rank():02d}] Loss is {loss}, stopping training!', force=True, flush=True)
            sys.exit(-1)

        # optimize
        grad_norm = None
        if early_clipping:
            grad_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip).item()
        optimizer.step()
        if late_clipping:
            grad_norm = optimizer.global_grad_norm
        # Wait for queued GPU work, presumably so that log_every's per-iteration timing is accurate.
        if synchronize_cuda:
            torch.cuda.synchronize()

        # log
        me.update(last_loss=loss)
        me.update(max_lr=max_lr)
        tb_lg.update(loss=me.meters['last_loss'].global_avg, head='train_loss')
        tb_lg.update(sche_lr=max_lr, head='train_hp/lr_max')
        tb_lg.update(sche_lr=min_lr, head='train_hp/lr_min')
        tb_lg.update(sche_wd=max_wd, head='train_hp/wd_max')
        tb_lg.update(sche_wd=min_wd, head='train_hp/wd_min')

        if grad_norm is not None:
            me.update(orig_norm=grad_norm)
            tb_lg.update(orig_norm=grad_norm, head='train_hp')
        tb_lg.set_step()

    me.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in me.meters.items()}
if __name__ == '__main__':
    # Launch pre-training only when this file is executed as a script, not on import.
    main_pt()