# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import sys
import time
from functools import partial
from typing import List

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader

import dist
import encoder
from decoder import LightDecoder
from models import build_sparse_encoder
from sampler import DistInfiniteBatchSampler, worker_init_fn
from spark import SparK
from utils import meta, misc, optim
from utils.imagenet import build_imagenet
from utils.lr_control import lr_wd_annealing, get_param_groups


def main_pt():
    args: meta.Args = meta.init_dist_and_get_args()
    print(f'global bs={args.glb_batch_size}, local bs={args.batch_size}')
    print(f'initial args:\n{str(args)}')
    args.log_epoch()
    
    # build data
    print(f'[build data for pre-training] ...\n')
    dataset_train, _ = build_imagenet('pt', args.data_path, args.data_set, args.input_size, eval_crop_pct=None, rrc=args.rrc)
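    # (reading of the code) the infinite batch sampler shards args.glb_batch_size across ranks
    # and yields an endless shuffled stream; len(data_loader_train) below therefore gives the
    # number of iterations treated as one epoch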
    data_loader_train = DataLoader(
        dataset=dataset_train, num_workers=args.num_workers, pin_memory=True,
        batch_sampler=DistInfiniteBatchSampler(
            dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, seed=args.seed,
            shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(),
        ), worker_init_fn=worker_init_fn
    )
    itrt_train = iter(data_loader_train)
    iters_train = len(data_loader_train)
    print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size}, iters_train={iters_train}')
    
    # build models (encoder, decoder, and other components)
    enc: encoder.SparseEncoder = build_sparse_encoder(args.model, input_size=args.input_size, sbn=args.sbn, drop_path_rate=args.dp, verbose=False)
    dec = LightDecoder(args.dec_dim, enc.downsample_raito, double=args.double, heavy=args.hea, cmid=args.cmid, sbn=args.sbn)
    spark = SparK(
        sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask, mask_ratio2=args.mask2, uniform=args.uni,
        using_pe=args.pe, pix_norm=args.pn, dense_loss=args.den, loss_l2=args.loss_l2,
        en_de_norm=args.en_de_norm, en_de_lin=args.en_de_lin, sbn=args.sbn, pyramid=args.py,
    )
    print(f'[PT model] model = {spark}\n')
    spark.to(args.device)
    model: DistributedDataParallel = DistributedDataParallel(spark, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
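    # keep a handle to the unwrapped SparK module: it is what gets checkpointed below and what
    # get_param_groups inspects when building the optimizer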
    model_without_ddp: SparK = model.module
    
    # build optimizer and lr_scheduler
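    # note: 'pos_embed', 'mask_token' and 'gamma' are excluded from weight decay, and the
    # optimizer is created with weight_decay=0.0, apparently because the effective weight decay
    # is applied per-iteration by lr_wd_annealing in pre_train_one_ep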
    param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'pos_embed', 'mask_token', 'gamma'}, lr_scale=0)
    opt_clz = {
        'sgd': partial(torch.optim.SGD, momentum=0.9, nesterov=True),
        'adamw': partial(torch.optim.AdamW, betas=(0.9, args.ada)),
        'lamb': partial(optim.TimmLAMB, betas=(0.9, args.ada), max_grad_norm=args.clip),
    }[args.opt]
    optimizer = opt_clz(params=param_groups, lr=args.lr, weight_decay=0.0)
    print(f'[optimizer] optimizer({opt_clz}) ={optimizer}\n')
    
    # try to resume
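    # misc.load_checkpoint returns the next epoch to run plus a performance summary;
    # with an empty args.resume, training starts from epoch 0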
    next_ep, performance_desc = misc.load_checkpoint(args.resume, model_without_ddp, optimizer) if len(args.resume) else (0, '[no performance_desc]')
    if next_ep >= args.ep:
        # load from a complete checkpoint file
        print(f' [*] [PT already done] Min/Last Recon Loss: {performance_desc}')
    else:
        # perform pre-training
        start_time = time.time()
        min_loss = 1e9
        print(f'[PT start] from ep{next_ep}')
        
        for ep in range(next_ep, args.ep):
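            # if the iterator exposes set_epoch, call it, presumably to re-seed per-epoch
            # shuffling (DistributedSampler-style convention)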
            if hasattr(itrt_train, 'set_epoch'):
                itrt_train.set_epoch(ep)
            
            stats, (sec, remain_time, finish_time) = pre_train_one_ep(ep, args, itrt_train, iters_train, model, optimizer)
            last_loss = stats['last_loss']
            min_loss = min(min_loss, last_loss)
            performance_desc = f'{min_loss:.4f} {last_loss:.4f}'
            print(f' [*] [ep{ep}] Min/Last Recon Loss: {performance_desc}, Remain: {remain_time}, Finish: {finish_time}')
            
            args.cur_phase = 'PT'
            args.cur_ep = f'{ep+1}/{args.ep}'
            args.remain_time, args.finish_time = str(remain_time), str(finish_time)
            args.last_loss = last_loss
            args.log_epoch()
            misc.save_checkpoint(f'ckpt-last.pth', args, ep, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())
        
        # finish pre-training
        print('\n\n')
        print(f' [*] [PT finished] Min/Last Recon Loss: {performance_desc}, Total Cost: {(time.time() - start_time) / 60 / 60:.1f}h')
        print('\n\n')
    
    misc.save_checkpoint(f'ckpt-final.pth', args, args.ep-1, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())
    args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time()))
    args.log_epoch()


def pre_train_one_ep(ep, args, itrt_train, iters_train, model: DistributedDataParallel, optimizer):
    model.train()
    me = misc.MetricLogger(delimiter=' ')
    me.add_meter('max_lr', misc.SmoothedValue(window_size=1, fmt='{value:.5f}'))
    header = f'[PT] Epoch: [{ep:3d}/{args.ep}]'
    
    optimizer.zero_grad()
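    # two clipping modes: optimizers exposing `global_grad_norm` (e.g. the LAMB variant built
    # with max_grad_norm above) handle clipping themselves, so the norm is only read back after
    # .step(); all other optimizers get a manual clip_grad_norm_ before .step()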
    early_clipping = args.clip > 0 and not hasattr(optimizer, 'global_grad_norm')
    late_clipping = args.clip > 0 and hasattr(optimizer, 'global_grad_norm')
    if early_clipping:
        params_req_grad = [p for p in model.parameters() if p.requires_grad]
    
    # for every batch do:
    for it, (inp, _) in enumerate(me.log_every(iters_train, itrt_train, 3, header)):
        # adjust lr and wd
        g_it = it + ep*iters_train
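        # anneal lr and weight decay every iteration (warmup length args.wp_ep*iters_train,
        # total schedule length args.ep*iters_train); the returned max_lr is logged below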
        min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer, args.lr, args.wd, args.wde, g_it, args.wp_ep*iters_train, args.ep*iters_train)
        
        # forward and backward
        inp = inp.to(args.device, non_blocking=True)
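        # the bare `SparK.forward` below is a no-op attribute access, presumably left as a
        # jump-to-definition hint for what the DDP-wrapped model(inp) call dispatches to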
        SparK.forward
        active_ex, rec, loss = model(inp)
        optimizer.zero_grad()
        loss.backward()
        loss = loss.item()
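        # abort on a non-finite loss; note that `force=True` is not an argument of the builtin
        # print, so this presumably relies on a distributed-aware print override from the repo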
        if not math.isfinite(loss):
            print(f'[rk{dist.get_rank():02d}] Loss is {loss}, stopping training!', force=True, flush=True)
            sys.exit(-1)
        
        # optimize
        grad_norm = None
        if early_clipping: grad_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip)
        optimizer.step()
        if late_clipping: grad_norm = optimizer.global_grad_norm
        torch.cuda.synchronize()
        
        # log
        me.update(last_loss=loss)
        me.update(max_lr=max_lr)
        if grad_norm is not None:
            me.update(orig_norm=grad_norm)
    
    me.synchronize_between_processes()
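    # return the globally averaged meters plus the iteration-time extrapolation
    # (seconds/iter, remaining time, estimated finish time) consumed by main_pt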
    return {k: meter.global_avg for k, meter in me.meters.items()}, me.iter_time.time_preds((args.ep-1-ep) * (iters_train+10))


if __name__ == '__main__':
    main_pt()