# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import sys
import time
from functools import partial
from typing import List

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader

import dist
import encoder
from decoder import LightDecoder
from models import build_sparse_encoder
from sampler import DistInfiniteBatchSampler, worker_init_fn
from spark import SparK
from utils import meta, misc, optim
from utils.imagenet import build_imagenet
from utils.lr_control import lr_wd_annealing, get_param_groups


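# main_pt(): entry point for SparK pre-training. High-level flow, as read from the code below:
# build the ImageNet loader -> build sparse encoder + light decoder -> wrap them in SparK ->
# wrap in DistributedDataParallel -> build optimizer -> optionally resume -> run the epoch loop.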
def main_pt():
    args: meta.Args = meta.init_dist_and_get_args()
    print(f'global bs={args.glb_batch_size}, local bs={args.batch_size}')
    print(f'initial args:\n{str(args)}')
    args.log_epoch()
    
    # build data
    print(f'[build data for pre-training] ...\n')
    dataset_train, _ = build_imagenet('pt', args.data_path, args.data_set, args.input_size, eval_crop_pct=None, rrc=args.rrc)
    data_loader_train = DataLoader(
        dataset=dataset_train, num_workers=args.num_workers, pin_memory=True,
        batch_sampler=DistInfiniteBatchSampler(
            dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, seed=args.seed,
            shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(),
        ), worker_init_fn=worker_init_fn
    )
    itrt_train = iter(data_loader_train)
    iters_train = len(data_loader_train)
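    # NOTE: the batch sampler is infinite (judging by its name and use here), hence the single iter()
    # above; each "epoch" simply consumes iters_train batches from itrt_train in pre_train_one_ep.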
    print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size}, iters_train={iters_train}')
    
    # build models (encoder, decoder, and other components)
    enc: encoder.SparseEncoder = build_sparse_encoder(args.model, input_size=args.input_size, sbn=args.sbn, drop_path_rate=args.dp, verbose=False)
    dec = LightDecoder(args.dec_dim, enc.downsample_raito, double=args.double, heavy=args.hea, cmid=args.cmid, sbn=args.sbn)  # 'downsample_raito' [sic]: spelled as the attribute is defined on the encoder
    spark = SparK(
        sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask, mask_ratio2=args.mask2, uniform=args.uni,
        using_pe=args.pe, pix_norm=args.pn, dense_loss=args.den, loss_l2=args.loss_l2,
        en_de_norm=args.en_de_norm, en_de_lin=args.en_de_lin, sbn=args.sbn, pyramid=args.py,
    )
    print(f'[PT model] model = {spark}\n')
    spark.to(args.device)
    model: DistributedDataParallel = DistributedDataParallel(spark, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
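    # model (the DDP wrapper) is used for forward/backward; model_without_ddp keeps a handle to the raw
    # SparK module so its state_dict(with_config=True) can be checkpointed without the 'module.' prefix.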
    model_without_ddp: SparK = model.module
    
    # build optimizer and lr_scheduler
    param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'pos_embed', 'mask_token', 'gamma'}, lr_scale=0)
    opt_clz = {
        'sgd': partial(torch.optim.SGD, momentum=0.9, nesterov=True),
        'adamw': partial(torch.optim.AdamW, betas=(0.9, args.ada)),
        'lamb': partial(optim.TimmLAMB, betas=(0.9, args.ada), max_grad_norm=args.clip),
    }[args.opt]
    optimizer = opt_clz(params=param_groups, lr=args.lr, weight_decay=0.0)
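    # NOTE (assumption): weight_decay is 0.0 here because lr_wd_annealing in pre_train_one_ep appears to
    # write the per-step lr/wd schedule directly into optimizer.param_groups, so a static wd would be overwritten.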
    print(f'[optimizer] optimizer({opt_clz}) = {optimizer}\n')
    
    # try to resume
    next_ep, performance_desc = misc.load_checkpoint(args.resume, model_without_ddp, optimizer) if len(args.resume) else (0, '[no performance_desc]')
    if next_ep >= args.ep:
        # load from a complete checkpoint file
        print(f' [*] [PT already done] Min/Last Recon Loss: {performance_desc}')
    else:
        # perform pre-training
        start_time = time.time()
        min_loss = 1e9
        print(f'[PT start] from ep{next_ep}')
        
        for ep in range(next_ep, args.ep):
            if hasattr(itrt_train, 'set_epoch'):
                itrt_train.set_epoch(ep)
            
            stats, (sec, remain_time, finish_time) = pre_train_one_ep(ep, args, itrt_train, iters_train, model, optimizer)
            last_loss = stats['last_loss']
            min_loss = min(min_loss, last_loss)
            performance_desc = f'{min_loss:.4f} {last_loss:.4f}'
            print(f' [*] [ep{ep}] Min/Last Recon Loss: {performance_desc}, Remain: {remain_time}, Finish: {finish_time}')
            
            args.cur_phase = 'PT'
            args.cur_ep = f'{ep+1}/{args.ep}'
            args.remain_time, args.finish_time = str(remain_time), str(finish_time)
            args.last_loss = last_loss
            args.log_epoch()
            misc.save_checkpoint('ckpt-last.pth', args, ep, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())
        
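        # NOTE: 'ckpt-last.pth' (written every epoch above) is presumably what args.resume points at when
        # restarting an interrupted run via misc.load_checkpoint.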
        # finish pre-training
        print('\n\n')
        print(f' [*] [PT finished] Min/Last Recon Loss: {performance_desc}, Total Cost: {(time.time() - start_time) / 60 / 60:.1f}h')
        print('\n\n')
        
        misc.save_checkpoint('ckpt-final.pth', args, args.ep-1, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())
        args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time()))
        args.log_epoch()


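# pre_train_one_ep(): runs one "epoch" of pre-training, i.e. iters_train steps drawn from the infinite
# iterator. Each step anneals lr/wd, runs forward/backward through the DDP-wrapped SparK model, clips
# gradients (early or late, depending on the optimizer), steps the optimizer, and logs metrics.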
def pre_train_one_ep(ep, args, itrt_train, iters_train, model: DistributedDataParallel, optimizer):
    model.train()
    me = misc.MetricLogger(delimiter=' ')
    me.add_meter('max_lr', misc.SmoothedValue(window_size=1, fmt='{value:.5f}'))
    header = f'[PT] Epoch: [{ep:3d}/{args.ep}]'
    
    optimizer.zero_grad()
    early_clipping = args.clip > 0 and not hasattr(optimizer, 'global_grad_norm')
    late_clipping = args.clip > 0 and hasattr(optimizer, 'global_grad_norm')
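    # "early" clipping: clip gradients manually (clip_grad_norm_) right before optimizer.step().
    # "late" clipping: the optimizer clips internally and exposes the norm as `global_grad_norm`
    # (presumably the TimmLAMB configured in main_pt with max_grad_norm=args.clip); it is only read back for logging.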
    if early_clipping:
        params_req_grad = [p for p in model.parameters() if p.requires_grad]
    
    # for every batch do:
    for it, (inp, _) in enumerate(me.log_every(iters_train, itrt_train, 3, header)):
        # adjust lr and wd
        g_it = it + ep * iters_train
        min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer, args.lr, args.wd, args.wde, g_it, args.wp_ep * iters_train, args.ep * iters_train)
        
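        # NOTE (assumption): lr/wd appear to be warmed up over the first wp_ep*iters_train steps and then
        # annealed over the full args.ep*iters_train steps; only max_lr is logged below, the other returned values are unused here.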
        # forward and backward
        inp = inp.to(args.device, non_blocking=True)
        SparK.forward  # no-op attribute reference; model(inp) dispatches to SparK.forward through the DDP wrapper (kept, presumably, as an IDE navigation hint)
        active_ex, rec, loss = model(inp)
        optimizer.zero_grad()
        loss.backward()
        loss = loss.item()
        if not math.isfinite(loss):
            # `force` is not a built-in print argument; it presumably comes from the rank-aware print set up by the repo's dist/misc utilities
            print(f'[rk{dist.get_rank():02d}] Loss is {loss}, stopping training!', force=True, flush=True)
            sys.exit(-1)
        
        # optimize
        grad_norm = None
        if early_clipping:
            grad_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip)
        optimizer.step()
        if late_clipping:
            grad_norm = optimizer.global_grad_norm
        torch.cuda.synchronize()
        
        # log
        me.update(last_loss=loss)
        me.update(max_lr=max_lr)
        if grad_norm is not None:
            me.update(orig_norm=grad_norm)
    
    me.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in me.meters.items()}, me.iter_time.time_preds((args.ep - 1 - ep) * (iters_train + 10))


if __name__ == '__main__':
    main_pt()