# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import math
import sys
import time
from functools import partial
from typing import List

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader

import dist
import encoder
from decoder import LightDecoder
from models import build_sparse_encoder
from sampler import DistInfiniteBatchSampler, worker_init_fn
from spark import SparK
from utils import arg_util, misc, lamb
from utils.imagenet import build_dataset_to_pretrain
from utils.lr_control import lr_wd_annealing, get_param_groups
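
# NOTE: the LocalDDP wrapper below mimics DistributedDataParallel's interface
# (it exposes `.module` and forwards calls), so single-process / non-distributed
# runs can reuse the same code path as real DDP training.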
class LocalDDP(torch.nn.Module):
    def __init__(self, module):
        super(LocalDDP, self).__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


def main_pt():
    args: arg_util.Args = arg_util.init_dist_and_get_args()
    print(f'initial args:\n{str(args)}')
    args.log_epoch()

    # build data
    print(f'[build data for pre-training] ...\n')
    dataset_train = build_dataset_to_pretrain(args.data_path, args.input_size)
    data_loader_train = DataLoader(
        dataset=dataset_train, num_workers=args.dataloader_workers, pin_memory=True,
        batch_sampler=DistInfiniteBatchSampler(
            dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size,
            shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(),
        ), worker_init_fn=worker_init_fn
    )
    itrt_train, iters_train = iter(data_loader_train), len(data_loader_train)
    print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size_per_gpu}, iters_train={iters_train}')
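    # NOTE: DistInfiniteBatchSampler appears to yield batches endlessly, so the iterator is
    # created once here and consumed for `iters_train` steps per epoch inside pre_train_one_ep,
    # rather than being re-created at every epoch.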

    # build encoder and decoder
    enc: encoder.SparseEncoder = build_sparse_encoder(args.model, input_size=args.input_size, sbn=args.sbn, drop_path_rate=args.dp, verbose=False)
    dec = LightDecoder(enc.downsample_raito, sbn=args.sbn)  # note: 'downsample_raito' is the attribute name as spelled by the encoder implementation
    model_without_ddp = SparK(
        sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask,
        densify_norm=args.densify_norm, sbn=args.sbn,
    ).to(args.device)
    print(f'[PT model] model = {model_without_ddp}\n')
    if dist.initialized():
        model: DistributedDataParallel = DistributedDataParallel(model_without_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
    else:
        model = LocalDDP(model_without_ddp)
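    # The DDP flags above (find_unused_parameters=False, broadcast_buffers=False) assume every
    # parameter receives a gradient each step and buffers need no per-step broadcast
    # (batch-norm synchronization is handled separately via the `sbn` option).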

    # build optimizer and lr_scheduler
    param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'cls_token', 'pos_embed', 'mask_token', 'gamma'})
    opt_clz = {
        'sgd': partial(torch.optim.SGD, momentum=0.9, nesterov=True),
        'adamw': partial(torch.optim.AdamW, betas=(0.9, args.ada)),
        'lamb': partial(lamb.TheSameAsTimmLAMB, betas=(0.9, args.ada), max_grad_norm=5.0),
    }[args.opt]
    optimizer = opt_clz(params=param_groups, lr=args.lr, weight_decay=0.0)
    print(f'[optimizer] optimizer({opt_clz}) ={optimizer}\n')
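    # NOTE: weight_decay is set to 0.0 here because lr_wd_annealing (called every iteration in
    # pre_train_one_ep) presumably writes the scheduled lr and wd values into the param groups directly.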

    # try to resume
    ep_start, performance_desc = misc.load_checkpoint(args.resume_from, model_without_ddp, optimizer)
    if ep_start >= args.ep:  # load from a complete checkpoint file
        print(f' [*] [PT already done] Min/Last Recon Loss: {performance_desc}')
    else:  # perform pre-training
        tb_lg = misc.TensorboardLogger(args.tb_lg_dir, is_master=dist.is_master(), prefix='pt')
        min_loss = 1e9
        print(f'[PT start] from ep{ep_start}')

        pt_start_time = time.time()
        for ep in range(ep_start, args.ep):
            ep_start_time = time.time()
            tb_lg.set_step(ep * iters_train)
            if hasattr(itrt_train, 'set_epoch'):
                itrt_train.set_epoch(ep)

            stats = pre_train_one_ep(ep, args, tb_lg, itrt_train, iters_train, model, optimizer)
            last_loss = stats['last_loss']
            min_loss = min(min_loss, last_loss)
            performance_desc = f'{min_loss:.4f} {last_loss:.4f}'
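            # save two checkpoints: the full training state (for resuming) and the
            # encoder-only weights (for downstream fine-tuning)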
            misc.save_checkpoint(f'{args.model}_still_pretraining.pth', args, ep, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())
            misc.save_checkpoint_for_finetune(f'{args.model}_1kpretrained.pth', args, model_without_ddp.sparse_encoder.sp_cnn.state_dict())

            ep_cost = round(time.time() - ep_start_time, 2) + 1  # +1s: approximate the following logging cost
            remain_secs = (args.ep - 1 - ep) * ep_cost
            remain_time = datetime.timedelta(seconds=round(remain_secs))
            finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs))
            print(f' [*] [ep{ep}/{args.ep}] Min/Last Recon Loss: {performance_desc}, Cost: {ep_cost}s, Remain: {remain_time}, Finish @ {finish_time}')

            args.cur_ep = f'{ep + 1}/{args.ep}'
            args.remain_time, args.finish_time = str(remain_time), str(finish_time)
            args.last_loss = last_loss
            args.log_epoch()

            tb_lg.update(min_loss=min_loss, head='train', step=ep)
            tb_lg.update(rest_hours=round(remain_secs/60/60, 2), head='z_burnout', step=ep)
            tb_lg.flush()

        # finish pre-training
        tb_lg.update(min_loss=min_loss, head='result', step=ep_start)
        tb_lg.update(min_loss=min_loss, head='result', step=args.ep)
        tb_lg.flush()
        print(f'final args:\n{str(args)}')
        print('\n\n')
        print(f' [*] [PT finished] Min/Last Recon Loss: {performance_desc}, Total Cost: {(time.time() - pt_start_time) / 60 / 60:.1f}h\n')
        print('\n\n')
        tb_lg.close()
        time.sleep(10)

    args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time()))
    args.log_epoch()


def pre_train_one_ep(ep, args: arg_util.Args, tb_lg: misc.TensorboardLogger, itrt_train, iters_train, model: DistributedDataParallel, optimizer):
    model.train()
    me = misc.MetricLogger(delimiter=' ')
    me.add_meter('max_lr', misc.SmoothedValue(window_size=1, fmt='{value:.5f}'))
    header = f'[PT] Epoch {ep}:'

    optimizer.zero_grad()
    early_clipping = args.clip > 0 and not hasattr(optimizer, 'global_grad_norm')
    late_clipping = hasattr(optimizer, 'global_grad_norm')
    if early_clipping:
        params_req_grad = [p for p in model.parameters() if p.requires_grad]
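    # NOTE: "early" clipping uses torch's clip_grad_norm_ before optimizer.step(); "late" clipping
    # relies on the LAMB optimizer (constructed with max_grad_norm=5.0), which appears to clip by
    # global grad norm inside step() and expose the measured norm as `optimizer.global_grad_norm`.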

    for it, (inp, _) in enumerate(me.log_every(iters_train, itrt_train, 3, header)):
        # adjust lr and wd
        min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer, args.lr, args.wd, args.wde, it + ep * iters_train, args.wp_ep * iters_train, args.ep * iters_train)

        # forward and backward
        inp = inp.to(args.device, non_blocking=True)
        SparK.forward  # no-op expression, presumably kept as a navigation hint: `model(...)` below dispatches to SparK.forward
        loss = model(inp, active_b1ff=None, vis=False)  # active_b1ff=None presumably lets SparK sample its own random mask (mask_ratio was set at construction)
        optimizer.zero_grad()
        loss.backward()
        loss = loss.item()
        if not math.isfinite(loss):
            print(f'[rk{dist.get_rank():02d}] Loss is {loss}, stopping training!', force=True, flush=True)  # `force` is presumably consumed by a dist-aware print override (builtin print does not accept it)
            sys.exit(-1)

        # optimize
        grad_norm = None
        if early_clipping: grad_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip).item()
        optimizer.step()
        if late_clipping: grad_norm = optimizer.global_grad_norm
        torch.cuda.synchronize()

        # log
        me.update(last_loss=loss)
        me.update(max_lr=max_lr)
        tb_lg.update(loss=me.meters['last_loss'].global_avg, head='train_loss')
        tb_lg.update(sche_lr=max_lr, head='train_hp/lr_max')
        tb_lg.update(sche_lr=min_lr, head='train_hp/lr_min')
        tb_lg.update(sche_wd=max_wd, head='train_hp/wd_max')
        tb_lg.update(sche_wd=min_wd, head='train_hp/wd_min')

        if grad_norm is not None:
            me.update(orig_norm=grad_norm)
            tb_lg.update(orig_norm=grad_norm, head='train_hp')
        tb_lg.set_step()

    me.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in me.meters.items()}


if __name__ == '__main__':
    main_pt()
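
# Example launch (a sketch only; the exact entry-point name and CLI flags depend on
# utils/arg_util.py and your environment, so treat the line below as hypothetical):
#   torchrun --nproc_per_node=8 main.py --data_path=/path/to/imagenet --model=resnet50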