You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

157 lines
6.9 KiB

2 years ago
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import sys
import time
from functools import partial
from typing import List
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
import dist
import encoder
from decoder import LightDecoder
from models import build_sparse_encoder
from sampler import DistInfiniteBatchSampler, worker_init_fn
from spark import SparK
from utils import meta, misc, optim
from utils.imagenet import build_imagenet
from utils.lr_control import lr_wd_annealing, get_param_groups
def _min_loss_from_desc(performance_desc, default=1e9):
    """Parse the best-so-far reconstruction loss out of a resumed ``performance_desc``.

    ``main_pt`` formats ``performance_desc`` as ``f'{min_loss:.4f} {last_loss:.4f}'``
    and hands it to ``misc.save_checkpoint``. This assumes ``misc.load_checkpoint``
    returns that same string on resume (TODO confirm against utils/misc). Anything
    whose first whitespace-separated token is not a finite float (e.g. the
    ``'[no performance_desc]'`` placeholder) yields ``default``, so the running
    minimum simply starts fresh as it did before.
    """
    try:
        value = float(str(performance_desc).split()[0])
    except (IndexError, ValueError):
        return default
    return value if math.isfinite(value) else default


def main_pt():
    """Entry point for SparK pre-training.

    Steps:
      1. Initialise distributed mode and parse args (``meta.init_dist_and_get_args``).
      2. Build the ImageNet ``'pt'`` data loader driven by ``DistInfiniteBatchSampler``.
      3. Build the sparse encoder + ``LightDecoder`` wrapped in ``SparK`` and DDP.
      4. Build the optimizer selected by ``args.opt`` (``'sgd'``, ``'adamw'`` or ``'lamb'``).
      5. Optionally resume from ``args.resume``.
      6. Train epochs ``[next_ep, args.ep)``, saving ``ckpt-last.pth`` after every
         epoch and ``ckpt-final.pth`` when training finishes.

    Raises:
        ValueError: if ``args.opt`` is not one of the supported optimizer names.
    """
    args: meta.Args = meta.init_dist_and_get_args()
    print(f'global bs={args.glb_batch_size}, local bs={args.batch_size}')
    print(f'initial args:\n{str(args)}')
    args.log_epoch()

    # build data
    print('[build data for pre-training] ...\n')
    dataset_train, _ = build_imagenet('pt', args.data_path, args.data_set, args.input_size, eval_crop_pct=None, rrc=args.rrc)
    data_loader_train = DataLoader(
        dataset=dataset_train, num_workers=args.num_workers, pin_memory=True,
        batch_sampler=DistInfiniteBatchSampler(
            dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, seed=args.seed,
            shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(),
        ), worker_init_fn=worker_init_fn
    )
    # A single iterator is created up front and shared across all epochs.
    itrt_train = iter(data_loader_train)
    iters_train = len(data_loader_train)
    print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size}, iters_train={iters_train}')

    # build models (encoder, decoder, and other components)
    enc: encoder.SparseEncoder = build_sparse_encoder(args.model, input_size=args.input_size, sbn=args.sbn, drop_path_rate=args.dp, verbose=False)
    dec = LightDecoder(args.dec_dim, enc.downsample_raito, double=args.double, heavy=args.hea, cmid=args.cmid, sbn=args.sbn)
    spark = SparK(
        sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask, mask_ratio2=args.mask2, uniform=args.uni,
        using_pe=args.pe, pix_norm=args.pn, dense_loss=args.den, loss_l2=args.loss_l2,
        en_de_norm=args.en_de_norm, en_de_lin=args.en_de_lin, sbn=args.sbn, pyramid=args.py,
    )
    print(f'[PT model] model = {spark}\n')
    spark.to(args.device)
    model: DistributedDataParallel = DistributedDataParallel(spark, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
    model_without_ddp: SparK = model.module

    # build optimizer and lr_scheduler; weight decay starts at 0 and is
    # set per-iteration by lr_wd_annealing inside pre_train_one_ep.
    param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'pos_embed', 'mask_token', 'gamma'}, lr_scale=0)
    opt_factories = {
        'sgd': partial(torch.optim.SGD, momentum=0.9, nesterov=True),
        'adamw': partial(torch.optim.AdamW, betas=(0.9, args.ada)),
        'lamb': partial(optim.TimmLAMB, betas=(0.9, args.ada), max_grad_norm=args.clip),
    }
    if args.opt not in opt_factories:
        raise ValueError(f'unknown optimizer {args.opt!r}; expected one of {sorted(opt_factories)}')
    opt_clz = opt_factories[args.opt]
    optimizer = opt_clz(params=param_groups, lr=args.lr, weight_decay=0.0)
    print(f'[optimizer] optimizer({opt_clz}) ={optimizer}\n')

    # try to resume (truthiness also tolerates resume=None, unlike len())
    if args.resume:
        next_ep, performance_desc = misc.load_checkpoint(args.resume, model_without_ddp, optimizer)
    else:
        next_ep, performance_desc = 0, '[no performance_desc]'

    if next_ep >= args.ep:
        # load from a complete checkpoint file
        print(f' [*] [PT already done] Min/Last Recon Loss: {performance_desc}')
    else:
        # perform pre-training
        start_time = time.time()
        # Carry the best loss across a resume instead of forgetting earlier epochs.
        min_loss = _min_loss_from_desc(performance_desc)
        print(f'[PT start] from ep{next_ep}')
        for ep in range(next_ep, args.ep):
            if hasattr(itrt_train, 'set_epoch'):
                itrt_train.set_epoch(ep)
            stats, (sec, remain_time, finish_time) = pre_train_one_ep(ep, args, itrt_train, iters_train, model, optimizer)
            # 'last_loss' is the meter's global average over this epoch, as returned by pre_train_one_ep.
            last_loss = stats['last_loss']
            min_loss = min(min_loss, last_loss)
            performance_desc = f'{min_loss:.4f} {last_loss:.4f}'
            print(f' [*] [ep{ep}] Min/Last Recon Loss: {performance_desc}, Remain: {remain_time}, Finish: {finish_time}')

            args.cur_phase = 'PT'
            args.cur_ep = f'{ep+1}/{args.ep}'
            args.remain_time, args.finish_time = str(remain_time), str(finish_time)
            args.last_loss = last_loss
            args.log_epoch()
            misc.save_checkpoint('ckpt-last.pth', args, ep, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())

        # finish pre-training
        print('\n\n')
        print(f' [*] [PT finished] Min/Last Recon Loss: {performance_desc}, Total Cost: {(time.time() - start_time) / 60 / 60:.1f}h')
        print('\n\n')
        misc.save_checkpoint('ckpt-final.pth', args, args.ep-1, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())

    # Runs on both paths (fresh/resumed training and already-finished checkpoint).
    args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time()))
    args.log_epoch()
def pre_train_one_ep(ep, args, itrt_train, iters_train, model: DistributedDataParallel, optimizer):
    """Run one pre-training epoch of ``iters_train`` iterations.

    Args:
        ep: zero-based index of the current epoch.
        args: run configuration. Reads ``ep``, ``clip``, ``lr``, ``wd``, ``wde``,
            ``wp_ep`` and ``device``.
        itrt_train: iterator of ``(inputs, labels)`` batches; labels are ignored.
        iters_train: number of iterations that make up one epoch.
        model: DDP-wrapped model; ``model(inp)`` must return a 3-tuple whose
            last element is the scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters. If it has a
            ``global_grad_norm`` attribute, it is assumed to clip gradients
            itself and the norm is read after ``step()``. Otherwise gradients
            are clipped here with ``clip_grad_norm_`` when ``args.clip > 0``.

    Returns:
        ``(stats, time_preds)``. ``stats`` maps each meter name (``last_loss``,
        ``max_lr`` and, when clipping, ``orig_norm``) to ``meter.global_avg``
        after ``me.synchronize_between_processes()``. ``time_preds`` is the
        result of ``me.iter_time.time_preds(...)``; the caller unpacks it as
        ``(sec, remain_time, finish_time)``.

    Exits the process with status -1 as soon as the loss is not finite.
    """
    model.train()
    me = misc.MetricLogger(delimiter=' ')
    me.add_meter('max_lr', misc.SmoothedValue(window_size=1, fmt='{value:.5f}'))
    header = f'[PT] Epoch: [{ep:3d}/{args.ep}]'
    optimizer.zero_grad()

    # "Late" clipping: the optimizer exposes `global_grad_norm` (presumably it clips
    # inside step(), like the LAMB variant built with max_grad_norm=args.clip), so
    # we only read the norm afterwards. "Early" clipping: clip here before step().
    clips_internally = hasattr(optimizer, 'global_grad_norm')
    early_clipping = args.clip > 0 and not clips_internally
    late_clipping = args.clip > 0 and clips_internally
    # Gathered once per epoch; only used for early clipping.
    params_req_grad = [p for p in model.parameters() if p.requires_grad] if early_clipping else []

    # for every batch do:
    for it, (inp, _) in enumerate(me.log_every(iters_train, itrt_train, 3, header)):
        # adjust lr and wd from the global iteration index; only max_lr is logged
        g_it = it + ep * iters_train
        _, max_lr, _, _ = lr_wd_annealing(optimizer, args.lr, args.wd, args.wde, g_it, args.wp_ep * iters_train, args.ep * iters_train)

        # forward and backward
        inp = inp.to(args.device, non_blocking=True)
        _, _, loss = model(inp)
        optimizer.zero_grad()
        loss.backward()
        loss = loss.item()
        if not math.isfinite(loss):
            print(f'[rk{dist.get_rank():02d}] Loss is {loss}, stopping training!', force=True, flush=True)
            sys.exit(-1)

        # optimize
        grad_norm = None
        if early_clipping:
            grad_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip)
        optimizer.step()
        if late_clipping:
            grad_norm = optimizer.global_grad_norm
        torch.cuda.synchronize()

        # log
        me.update(last_loss=loss)
        me.update(max_lr=max_lr)
        if grad_norm is not None:
            me.update(orig_norm=grad_norm)

    me.synchronize_between_processes()
    stats = {name: meter.global_avg for name, meter in me.meters.items()}
    # NOTE(review): the +10 per remaining-epoch iteration count looks like slack
    # for per-epoch overhead in the ETA estimate — confirm intent.
    return stats, me.iter_time.time_preds((args.ep - 1 - ep) * (iters_train + 10))
# Script entry point: run pre-training only when this file is executed directly.
if __name__ == '__main__':
    main_pt()