You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

185 lines
8.1 KiB

# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import math
import sys
import time
from functools import partial
from typing import List
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
import dist
import encoder
from decoder import LightDecoder
from models import build_sparse_encoder
from sampler import DistInfiniteBatchSampler, worker_init_fn
from spark import SparK
from utils import arg_util, misc, lamb
from utils.imagenet import build_dataset_to_pretrain
from utils.lr_control import lr_wd_annealing, get_param_groups
class LocalDDP(torch.nn.Module):
    """Pass-through wrapper that exposes a model under a ``.module`` attribute.

    ``main_pt`` uses it instead of ``DistributedDataParallel`` when the
    distributed backend is not initialized, so the training code can call the
    wrapped and the DDP-wrapped model the same way.
    """

    def __init__(self, module):
        """Store ``module`` as ``self.module`` (registered as a sub-module)."""
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        """Forward every positional and keyword argument to the wrapped module."""
        wrapped = self.module
        return wrapped(*args, **kwargs)
def main_pt():
    """Entry point of pre-training: build everything, then train epoch by epoch.

    Steps, in order:
      1. Obtain the run arguments from ``arg_util.init_dist_and_get_args()``.
      2. Build the pre-training dataset and a ``DataLoader`` driven by
         ``DistInfiniteBatchSampler``.
      3. Build the sparse encoder, the light decoder and the ``SparK`` model,
         wrapped in ``DistributedDataParallel`` when ``dist.initialized()``
         is true and in ``LocalDDP`` otherwise.
      4. Build the optimizer selected by ``args.opt`` ('sgd', 'adamw' or 'lamb').
      5. Try to resume from ``args.resume_from``; if the checkpoint already
         covers ``args.ep`` epochs, only report it, otherwise run the
         remaining epochs, saving two checkpoints after every epoch.
      6. Record the finish time on ``args`` and log it.
    """
    args: arg_util.Args = arg_util.init_dist_and_get_args()
    print(f'initial args:\n{str(args)}')
    args.log_epoch()

    # ---- data ----
    print(f'[build data for pre-training] ...\n')
    dataset_train = build_dataset_to_pretrain(args.data_path, args.input_size)
    batch_sampler = DistInfiniteBatchSampler(
        dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size,
        shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(),
    )
    data_loader_train = DataLoader(
        dataset=dataset_train,
        num_workers=args.dataloader_workers,
        pin_memory=True,
        batch_sampler=batch_sampler,
        worker_init_fn=worker_init_fn,
    )
    itrt_train = iter(data_loader_train)
    iters_train = len(data_loader_train)
    print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size_per_gpu}, iters_train={iters_train}')

    # ---- encoder, decoder and the full pre-training model ----
    enc: encoder.SparseEncoder = build_sparse_encoder(
        args.model, input_size=args.input_size, sbn=args.sbn, drop_path_rate=args.dp, verbose=False,
    )
    dec = LightDecoder(enc.downsample_raito, sbn=args.sbn)
    model_without_ddp = SparK(
        sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask,
        densify_norm=args.densify_norm, sbn=args.sbn,
    ).to(args.device)
    print(f'[PT model] model = {model_without_ddp}\n')

    if dist.initialized():
        model: DistributedDataParallel = DistributedDataParallel(
            model_without_ddp, device_ids=[dist.get_local_rank()],
            find_unused_parameters=False, broadcast_buffers=False,
        )
    else:
        model = LocalDDP(model_without_ddp)

    # ---- optimizer (lr / wd are presumably re-scheduled per iteration by
    # lr_wd_annealing inside pre_train_one_ep -- confirm in utils.lr_control) ----
    param_groups: List[dict] = get_param_groups(
        model_without_ddp, nowd_keys={'cls_token', 'pos_embed', 'mask_token', 'gamma'},
    )
    optimizer_factories = {
        'sgd': partial(torch.optim.SGD, momentum=0.9, nesterov=True),
        'adamw': partial(torch.optim.AdamW, betas=(0.9, args.ada)),
        'lamb': partial(lamb.TheSameAsTimmLAMB, betas=(0.9, args.ada), max_grad_norm=5.0),
    }
    make_optimizer = optimizer_factories[args.opt]
    optimizer = make_optimizer(params=param_groups, lr=args.lr, weight_decay=0.0)
    print(f'[optimizer] optimizer({make_optimizer}) ={optimizer}\n')

    # ---- resume ----
    ep_start, performance_desc = misc.load_checkpoint(args.resume_from, model_without_ddp, optimizer)
    already_done = ep_start >= args.ep
    if not already_done:
        # ---- pre-training loop ----
        tb_lg = misc.TensorboardLogger(args.tb_lg_dir, is_master=dist.is_master(), prefix='pt')
        min_loss = 1e9
        print(f'[PT start] from ep{ep_start}')

        pt_start_time = time.time()
        for ep in range(ep_start, args.ep):
            ep_start_time = time.time()
            tb_lg.set_step(ep * iters_train)
            if hasattr(itrt_train, 'set_epoch'):
                itrt_train.set_epoch(ep)

            epoch_stats = pre_train_one_ep(ep, args, tb_lg, itrt_train, iters_train, model, optimizer)
            last_loss = epoch_stats['last_loss']
            min_loss = min(min_loss, last_loss)
            performance_desc = f'{min_loss:.4f} {last_loss:.4f}'

            # One resumable checkpoint (model + optimizer) and one holding only
            # the encoder's sp_cnn weights for later fine-tuning.
            misc.save_checkpoint(
                f'{args.model}_still_pretraining.pth', args, ep, performance_desc,
                model_without_ddp.state_dict(with_config=True), optimizer.state_dict(),
            )
            misc.save_checkpoint_for_finetune(
                f'{args.model}_1kpretrained.pth', args,
                model_without_ddp.sparse_encoder.sp_cnn.state_dict(),
            )

            # Estimate the remaining time from this epoch's duration; the extra
            # second approximates the cost of the logging below.
            epoch_seconds = round(time.time() - ep_start_time, 2) + 1
            epochs_left = args.ep-1 - ep
            remain_secs = epochs_left * epoch_seconds
            remain_time = datetime.timedelta(seconds=round(remain_secs))
            finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs))
            print(f' [*] [ep{ep}/{args.ep}] Min/Last Recon Loss: {performance_desc}, Cost: {epoch_seconds}s, Remain: {remain_time}, Finish @ {finish_time}')

            args.cur_ep = f'{ep + 1}/{args.ep}'
            args.remain_time = str(remain_time)
            args.finish_time = str(finish_time)
            args.last_loss = last_loss
            args.log_epoch()

            tb_lg.update(min_loss=min_loss, head='train', step=ep)
            tb_lg.update(rest_hours=round(remain_secs/60/60, 2), head='z_burnout', step=ep)
            tb_lg.flush()

        # ---- wrap up: log the best loss at both ends of the epoch range ----
        for result_step in (ep_start, args.ep):
            tb_lg.update(min_loss=min_loss, head='result', step=result_step)
        tb_lg.flush()
        print(f'final args:\n{str(args)}')
        print('\n\n')
        print(f' [*] [PT finished] Min/Last Recon Loss: {performance_desc}, Total Cost: {(time.time() - pt_start_time) / 60 / 60:.1f}h\n')
        print('\n\n')
        tb_lg.close()
        # NOTE(review): presumably gives background writers time to finish -- confirm.
        time.sleep(10)
    else:
        # The loaded checkpoint already covers every requested epoch.
        print(f' [*] [PT already done] Min/Last Recon Loss: {performance_desc}')

    finished_at = time.strftime("%m-%d %H:%M", time.localtime(time.time()))
    args.remain_time = '-'
    args.finish_time = finished_at
    args.log_epoch()
def pre_train_one_ep(ep, args: arg_util.Args, tb_lg: misc.TensorboardLogger, itrt_train, iters_train, model: DistributedDataParallel, optimizer):
    """Run one pre-training epoch and return the averaged metrics.

    Args:
        ep: Index of the current epoch; used in the log header and to compute
            the global iteration ``it + ep * iters_train``.
        args: Run arguments. This function reads ``clip``, ``lr``, ``wd``,
            ``wde``, ``wp_ep``, ``ep`` and ``device``.
        tb_lg: Tensorboard logger that receives the per-iteration scalars.
        itrt_train: Iterator yielding ``(inp, _)`` batches; the second element
            is ignored.
        iters_train: Number of iterations forwarded to ``me.log_every`` and
            used to derive the global iteration index.
        model: The wrapped model; calling it returns the loss tensor.
        optimizer: Optimizer to step. If it exposes ``global_grad_norm`` the
            gradient norm is read from it after ``step()`` and no explicit
            clipping is done here.

    Returns:
        Dict mapping each meter name of the metric logger to its
        ``global_avg`` (``main_pt`` reads the ``'last_loss'`` entry).

    Exits the process with status -1 when the loss is not finite.
    """
    model.train()
    me = misc.MetricLogger(delimiter=' ')
    me.add_meter('max_lr', misc.SmoothedValue(window_size=1, fmt='{value:.5f}'))
    header = f'[PT] Epoch {ep}:'

    optimizer.zero_grad()
    # "early" clipping: clip here, before optimizer.step().
    # "late" clipping: the optimizer exposes its own global_grad_norm, which is
    # only read back after step().
    early_clipping = args.clip > 0 and not hasattr(optimizer, 'global_grad_norm')
    late_clipping = hasattr(optimizer, 'global_grad_norm')
    # Always bound, so the name can never be undefined further down.
    params_req_grad = [p for p in model.parameters() if p.requires_grad] if early_clipping else []
    # torch.cuda.synchronize() raises on hosts without CUDA; check once up front.
    cuda_available = torch.cuda.is_available()

    for it, (inp, _) in enumerate(me.log_every(iters_train, itrt_train, 3, header)):
        # adjust lr and wd
        min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(
            optimizer, args.lr, args.wd, args.wde,
            it + ep * iters_train, args.wp_ep * iters_train, args.ep * iters_train,
        )

        # forward and backward (the wrapper built in main_pt dispatches to SparK.forward)
        inp = inp.to(args.device, non_blocking=True)
        loss = model(inp, active_b1ff=None, vis=False)
        optimizer.zero_grad()
        loss.backward()
        loss = loss.item()
        if not math.isfinite(loss):
            # NOTE(review): `force=` is not a keyword of the builtin print();
            # presumably print is patched elsewhere in the project -- confirm.
            print(f'[rk{dist.get_rank():02d}] Loss is {loss}, stopping training!', force=True, flush=True)
            sys.exit(-1)

        # optimize
        grad_norm = None
        if early_clipping:
            grad_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip).item()
        optimizer.step()
        if late_clipping:
            grad_norm = optimizer.global_grad_norm
        if cuda_available:
            torch.cuda.synchronize()

        # log
        me.update(last_loss=loss)
        me.update(max_lr=max_lr)
        tb_lg.update(loss=me.meters['last_loss'].global_avg, head='train_loss')
        tb_lg.update(sche_lr=max_lr, head='train_hp/lr_max')
        tb_lg.update(sche_lr=min_lr, head='train_hp/lr_min')
        tb_lg.update(sche_wd=max_wd, head='train_hp/wd_max')
        tb_lg.update(sche_wd=min_wd, head='train_hp/wd_min')

        if grad_norm is not None:
            me.update(orig_norm=grad_norm)
            tb_lg.update(orig_norm=grad_norm, head='train_hp')
        tb_lg.set_step()

    me.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in me.meters.items()}
# Script entry point: start pre-training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main_pt()