|
|
|
# Copyright (c) ByteDance, Inc. and its affiliates.
|
|
|
|
# All rights reserved.
|
|
|
|
#
|
|
|
|
# This source code is licensed under the license found in the
|
|
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
import datetime
|
|
|
|
import math
|
|
|
|
import sys
|
|
|
|
import time
|
|
|
|
from functools import partial
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from torch.nn.parallel import DistributedDataParallel
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
|
|
import dist
|
|
|
|
import encoder
|
|
|
|
from decoder import LightDecoder
|
|
|
|
from models import build_sparse_encoder
|
|
|
|
from sampler import DistInfiniteBatchSampler, worker_init_fn
|
|
|
|
from spark import SparK
|
|
|
|
from utils import arg_util, misc, lamb
|
|
|
|
from utils.imagenet import build_imagenet_pretrain
|
|
|
|
from utils.lr_control import lr_wd_annealing, get_param_groups
|
|
|
|
|
|
|
|
|
|
|
|
def main_pt():
    """Drive the whole pre-training run.

    Steps, in order:
      1. initialise the distributed context / arguments and log them;
      2. build the training dataset, its distributed batch sampler and loader;
      3. build the sparse encoder, the light decoder and the ``SparK`` wrapper,
         then wrap the result in ``DistributedDataParallel``;
      4. build the optimizer selected by ``args.opt``;
      5. ask ``misc.load_checkpoint`` for the starting epoch and, unless that
         epoch already reaches ``args.ep``, train epoch by epoch, saving two
         checkpoint files and logging progress after every epoch;
      6. record the finish time on ``args`` and log once more.
    """
    args: arg_util.Args = arg_util.init_dist_and_get_args()
    print(f'initial args:\n{str(args)}')
    args.log_epoch()

    # ---- data ---------------------------------------------------------
    print('[build data for pre-training] ...\n')
    train_set = build_imagenet_pretrain(args.data_path, args.input_size)
    batch_sampler = DistInfiniteBatchSampler(
        dataset_len=len(train_set),
        glb_batch_size=args.glb_batch_size,
        seed=args.seed,
        shuffle=True,
        filling=True,
        rank=dist.get_rank(),
        world_size=dist.get_world_size(),
    )
    train_loader = DataLoader(
        dataset=train_set,
        num_workers=args.dataloader_workers,
        pin_memory=True,
        batch_sampler=batch_sampler,
        worker_init_fn=worker_init_fn,
    )
    # A single iterator is created here and reused by every epoch.
    batch_iter = iter(train_loader)
    iters_per_ep = len(train_loader)
    print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size_per_gpu}, iters_train={iters_per_ep}')

    # ---- encoder / decoder / wrapper ----------------------------------
    enc: encoder.SparseEncoder = build_sparse_encoder(
        args.model,
        input_size=args.input_size,
        sbn=args.sbn,
        drop_path_rate=args.dp,
        verbose=False,
    )
    # NOTE: `downsample_raito` is the attribute name as spelled on the encoder.
    dec = LightDecoder(enc.downsample_raito, sbn=args.sbn)
    raw_model = SparK(
        sparse_encoder=enc,
        dense_decoder=dec,
        mask_ratio=args.mask,
        densify_norm=args.densify_norm,
        sbn=args.sbn,
        hierarchy=args.hierarchy,
    ).to(args.device)
    print(f'[PT model] model = {raw_model}\n')
    model: DistributedDataParallel = DistributedDataParallel(
        raw_model,
        device_ids=[dist.get_local_rank()],
        find_unused_parameters=False,
        broadcast_buffers=False,
    )

    # ---- optimizer ----------------------------------------------------
    no_decay_keys = {'cls_token', 'pos_embed', 'mask_token', 'gamma'}
    param_groups: List[dict] = get_param_groups(raw_model, nowd_keys=no_decay_keys)
    optimizer_factories = {
        'sgd': partial(torch.optim.SGD, momentum=0.9, nesterov=True),
        'adamw': partial(torch.optim.AdamW, betas=(0.9, args.ada)),
        'lamb': partial(lamb.TheSameAsTimmLAMB, betas=(0.9, args.ada), max_grad_norm=5.0),
    }
    make_optimizer = optimizer_factories[args.opt]  # KeyError on an unknown optimizer name
    # weight_decay is 0.0 here; lr / wd are rewritten per step by lr_wd_annealing
    # in the per-epoch routine.
    optimizer = make_optimizer(params=param_groups, lr=args.lr, weight_decay=0.0)
    print(f'[optimizer] optimizer({make_optimizer}) ={optimizer}\n')

    # ---- resume -------------------------------------------------------
    ep_start, performance_desc = misc.load_checkpoint(args.resume_from, raw_model, optimizer)
    if ep_start >= args.ep:
        # the checkpoint already covers every requested epoch
        print(f' [*] [PT already done] Min/Last Recon Loss: {performance_desc}')
    else:
        tb_lg = misc.TensorboardLogger(args.tb_lg_dir, is_master=dist.is_master(), prefix='pt')
        min_loss = 1e9
        print(f'[PT start] from ep{ep_start}')

        pt_start_time = time.time()
        for epoch in range(ep_start, args.ep):
            epoch_began = time.time()
            tb_lg.set_step(epoch * iters_per_ep)
            if hasattr(batch_iter, 'set_epoch'):
                batch_iter.set_epoch(epoch)

            stats = pre_train_one_ep(epoch, args, tb_lg, batch_iter, iters_per_ep, model, optimizer)
            last_loss = stats['last_loss']
            min_loss = min(min_loss, last_loss)
            performance_desc = f'{min_loss:.4f} {last_loss:.4f}'

            # Full training state for resuming ...
            misc.save_checkpoint(
                f'{args.model}_still_pretraining.pth',
                args,
                epoch,
                performance_desc,
                raw_model.state_dict(with_config=True),
                optimizer.state_dict(),
            )
            # ... and the sparse encoder's `sp_cnn` weights alone for fine-tuning.
            misc.save_checkpoint_for_finetune(
                f'{args.model}_1kpretrained.pth',
                args,
                raw_model.sparse_encoder.sp_cnn.state_dict(),
            )

            # +1s: approximate the following logging cost
            epoch_cost = round(time.time() - epoch_began, 2) + 1
            remain_secs = (args.ep - 1 - epoch) * epoch_cost
            remain_time = datetime.timedelta(seconds=round(remain_secs))
            finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs))
            print(f' [*] [ep{epoch}/{args.ep}] Min/Last Recon Loss: {performance_desc}, Cost: {epoch_cost}s, Remain: {remain_time}, Finish @ {finish_time}')

            # Mirror the progress onto args so that log_epoch() can report it.
            args.cur_ep = f'{epoch + 1}/{args.ep}'
            args.remain_time = str(remain_time)
            args.finish_time = str(finish_time)
            args.last_loss = last_loss
            args.log_epoch()

            tb_lg.update(min_loss=min_loss, head='train', step=epoch)
            tb_lg.update(rest_hours=round(remain_secs / 60 / 60, 2), head='z_burnout', step=epoch)
            tb_lg.flush()

        # ---- finish pre-training --------------------------------------
        # The final minimum is written at both the first and the last epoch index.
        for result_step in (ep_start, args.ep):
            tb_lg.update(min_loss=min_loss, head='result', step=result_step)
        tb_lg.flush()
        print(f'final args:\n{str(args)}')
        print('\n\n')
        total_hours = (time.time() - pt_start_time) / 60 / 60
        print(f' [*] [PT finished] Min/Last Recon Loss: {performance_desc}, Total Cost: {total_hours:.1f}h\n')
        print('\n\n')
        tb_lg.close()
        time.sleep(10)

    finished_at = time.strftime("%m-%d %H:%M", time.localtime(time.time()))
    args.remain_time = '-'
    args.finish_time = finished_at
    args.log_epoch()
|
|
|
|
|
|
|
|
|
|
|
|
def pre_train_one_ep(ep, args: arg_util.Args, tb_lg: misc.TensorboardLogger, itrt_train, iters_train, model: DistributedDataParallel, optimizer):
    """Run one epoch of pre-training and return the averaged meters.

    Args:
        ep: index of the current epoch; used in the log header and to compute
            the global step handed to ``lr_wd_annealing``.
        args: run configuration; this function reads ``clip``, ``lr``, ``wd``,
            ``wde``, ``wp_ep``, ``ep`` and ``device``.
        tb_lg: logger that receives the per-iteration scalars.
        itrt_train: iterable yielding ``(input, _)`` pairs; the second element
            is ignored.
        iters_train: iteration count passed to ``log_every`` and used to turn
            epochs into step counts.
        model: called on each input batch; its third return value is used as
            the loss.
        optimizer: stepped once per iteration. If it exposes a
            ``global_grad_norm`` attribute, that value is logged after the
            step and no explicit gradient clipping is done here.

    Returns:
        A dict mapping every meter name to that meter's ``global_avg``.

    Exits the process with status -1 (``sys.exit``) when the loss is not finite.
    """
    model.train()
    meters = misc.MetricLogger(delimiter=' ')
    meters.add_meter('max_lr', misc.SmoothedValue(window_size=1, fmt='{value:.5f}'))
    header = f'[PT] Epoch {ep}:'

    optimizer.zero_grad()
    # "early" clipping: clip explicitly before optimizer.step();
    # "late" clipping: only read back the norm the optimizer exposes after its step.
    early_clipping = args.clip > 0 and not hasattr(optimizer, 'global_grad_norm')
    late_clipping = hasattr(optimizer, 'global_grad_norm')
    if early_clipping:
        params_req_grad = [prm for prm in model.parameters() if prm.requires_grad]

    for it, (batch, _) in enumerate(meters.log_every(iters_train, itrt_train, 3, header)):
        # adjust lr and wd
        cur_step = it + ep * iters_train
        warmup_steps = args.wp_ep * iters_train
        total_steps = args.ep * iters_train
        min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(
            optimizer, args.lr, args.wd, args.wde, cur_step, warmup_steps, total_steps,
        )

        # forward and backward
        batch = batch.to(args.device, non_blocking=True)
        SparK.forward  # bare attribute reference with no runtime effect
        _, _, loss = model(batch)
        optimizer.zero_grad()
        loss.backward()
        loss = loss.item()
        if not math.isfinite(loss):
            print(f'[rk{dist.get_rank():02d}] Loss is {loss}, stopping training!', force=True, flush=True)
            sys.exit(-1)

        # optimize
        grad_norm = None
        if early_clipping:
            grad_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip).item()
        optimizer.step()
        if late_clipping:
            grad_norm = optimizer.global_grad_norm
        torch.cuda.synchronize()

        # log
        meters.update(last_loss=loss)
        meters.update(max_lr=max_lr)
        tb_lg.update(loss=meters.meters['last_loss'].global_avg, head='train_loss')
        schedule_scalars = (
            ('sche_lr', max_lr, 'train_hp/lr_max'),
            ('sche_lr', min_lr, 'train_hp/lr_min'),
            ('sche_wd', max_wd, 'train_hp/wd_max'),
            ('sche_wd', min_wd, 'train_hp/wd_min'),
        )
        for scalar_name, scalar_value, scalar_head in schedule_scalars:
            tb_lg.update(**{scalar_name: scalar_value}, head=scalar_head)

        if grad_norm is not None:
            meters.update(orig_norm=grad_norm)
            tb_lg.update(orig_norm=grad_norm, head='train_hp')
        tb_lg.set_step()

    meters.synchronize_between_processes()
    return {name: meter.global_avg for name, meter in meters.meters.items()}
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: start pre-training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main_pt()
|