# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import time

import torch
import torch.distributed as tdist
from timm.utils import ModelEmaV2
from torch.utils.tensorboard import SummaryWriter

from arg import get_args, FineTuneArgs
from models import ConvNeXt, ResNet
__for_timm_registration = ConvNeXt, ResNet
from lr_decay import lr_wd_annealing
from util import init_distributed_environ, create_model_opt, load_checkpoint, save_checkpoint
from data import create_classification_dataset


def main_ft():
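    # classification fine-tuning entry point: set up the distributed environment, build the
    # model/EMA/optimizer, optionally resume from a checkpoint, then alternate training and evaluation per epoch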
    world_size, global_rank, local_rank, device = init_distributed_environ()
    args: FineTuneArgs = get_args(world_size, global_rank, local_rank, device)
    print(f'initial args:\n{str(args)}')
    args.log_epoch()

    criterion, mixup_fn, model_without_ddp, model, model_ema, optimizer = create_model_opt(args)
    ep_start, performance_desc = load_checkpoint(args.resume_from, model_without_ddp, model_ema, optimizer)

    if ep_start >= args.ep:  # load from a complete checkpoint file
        print(f' [*] [FT already done] Max/Last Acc: {performance_desc}')
    else:
        tb_lg = SummaryWriter(args.tb_lg_dir) if args.is_master else None
        loader_train, iters_train, iterator_val, iters_val = create_classification_dataset(
            args.data_path, args.img_size, args.rep_aug,
            args.dataloader_workers, args.batch_size_per_gpu, args.world_size, args.global_rank
        )

        # train & eval
        tot_pred, last_acc = evaluate(args.device, iterator_val, iters_val, model)
        max_acc = last_acc
        max_acc_e = last_acc_e = evaluate(args.device, iterator_val, iters_val, model_ema.module)[-1]
        print(f'[fine-tune] initial acc={last_acc:.2f}, ema={last_acc_e:.2f}')

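        # evaluation schedule: every 5 epochs during the first third of training, then every epoch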
        ep_eval = set(range(0, args.ep//3, 5)) | set(range(args.ep//3, args.ep))
        print(f'[FT start] ep_eval={sorted(ep_eval)}')
        print(f'[FT start] from ep{ep_start}')

        params_req_grad = [p for p in model.parameters() if p.requires_grad]
        ft_start_time = time.time()
        for ep in range(ep_start, args.ep):
            ep_start_time = time.time()
            if hasattr(loader_train, 'sampler') and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(ep)
                if 0 <= ep <= 3:
                    print(f'[loader_train.sampler.set_epoch({ep})]')

            train_loss, train_acc = fine_tune_one_epoch(ep, args, tb_lg, loader_train, iters_train, criterion, mixup_fn, model, model_ema, optimizer, params_req_grad)
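            # periodically evaluate both the raw model and its EMA copy; keep the best checkpoint of each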
            if ep in ep_eval:
                eval_start_time = time.time()
                tot_pred, last_acc = evaluate(args.device, iterator_val, iters_val, model)
                tot_pred_e, last_acc_e = evaluate(args.device, iterator_val, iters_val, model_ema.module)
                eval_cost = round(time.time() - eval_start_time, 2)
                performance_desc = f'Max (Last) Acc: {max(max_acc, last_acc):.2f} ({last_acc:.2f} o {tot_pred}) EMA: {max(max_acc_e, last_acc_e):.2f} ({last_acc_e:.2f} o {tot_pred_e})'
                states = model_without_ddp.state_dict(), model_ema.module.state_dict(), optimizer.state_dict()
                if last_acc > max_acc:
                    max_acc = last_acc
                    save_checkpoint(f'{args.model}_1kfinetuned_best.pth', args, ep, performance_desc, *states)
                if last_acc_e > max_acc_e:
                    max_acc_e = last_acc_e
                    save_checkpoint(f'{args.model}_1kfinetuned_best_ema.pth', args, ep, performance_desc, *states)
                save_checkpoint(f'{args.model}_1kfinetuned_last.pth', args, ep, performance_desc, *states)
            else:
                eval_cost = '-'

            ep_cost = round(time.time() - ep_start_time, 2) + 1  # +1s: approximate the following logging cost
            remain_secs = (args.ep-1 - ep) * ep_cost
            remain_time = datetime.timedelta(seconds=round(remain_secs))
            finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs))
            print(f'[ep{ep}/{args.ep}] {performance_desc} Ep cost: {ep_cost}s, Ev cost: {eval_cost}, Remain: {remain_time}, Finish @ {finish_time}')

            args.cur_ep = f'{ep + 1}/{args.ep}'
            args.remain_time, args.finish_time = str(remain_time), str(finish_time)
            args.train_loss, args.train_acc, args.best_val_acc = train_loss, train_acc, max(max_acc, max_acc_e)
            args.log_epoch()

            if args.is_master:
                tb_lg.add_scalar(f'ft_train/ep_loss', train_loss, ep)
                tb_lg.add_scalar(f'ft_eval/max_acc', max_acc, ep)
                tb_lg.add_scalar(f'ft_eval/last_acc', last_acc, ep)
                tb_lg.add_scalar(f'ft_eval/max_acc_ema', max_acc_e, ep)
                tb_lg.add_scalar(f'ft_eval/last_acc_ema', last_acc_e, ep)
                tb_lg.add_scalar(f'ft_z_burnout/rest_hours', round(remain_secs/60/60, 2), ep)
                tb_lg.flush()

        # finish fine-tuning
        result_acc = max(max_acc, max_acc_e)
        if args.is_master:
            tb_lg.add_scalar('ft_result/result_acc', result_acc, ep_start)
            tb_lg.add_scalar('ft_result/result_acc', result_acc, args.ep)
            tb_lg.flush()
            tb_lg.close()

        print(f'final args:\n{str(args)}')
        print('\n\n')
        print(f' [*] [FT finished] {performance_desc} Total Cost: {(time.time() - ft_start_time) / 60 / 60:.1f}h\n')
        print(f' [*] [FT finished] max(max_acc, max_acc_e)={result_acc} EMA better={max_acc_e>max_acc}')
        print('\n\n')
        time.sleep(10)

    args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time()))
    args.log_epoch()


def fine_tune_one_epoch(ep, args: FineTuneArgs, tb_lg: SummaryWriter, loader_train, iters_train, criterion, mixup_fn, model, model_ema: ModelEmaV2, optimizer, params_req_grad):
    model.train()
    tot_loss = tot_acc = 0.0
    log_freq = max(1, round(iters_train * 0.7))
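    # with log_freq ~ 0.7 * iters_train, TensorBoard scalars are written roughly once every 0.7 epoch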
    ep_start_time = time.time()
    for it, (inp, tar) in enumerate(loader_train):
        # adjust lr and wd
        cur_it = it + ep * iters_train
        min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer, args.lr, args.wd, cur_it, args.wp_ep * iters_train, args.ep * iters_train)
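        # the returned min/max lr and wd values are only used for logging below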

        # forward
        inp = inp.to(args.device, non_blocking=True)
        raw_tar = tar = tar.to(args.device, non_blocking=True)
        if mixup_fn is not None:
            inp, tar = mixup_fn(inp, tar)
        oup = model(inp)
        pred = oup.data.argmax(dim=1)
        if mixup_fn is None:
            acc = pred.eq(tar).float().mean().item() * 100
            tot_acc += acc
        else:
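            # under mixup (timm blends each sample with the batch-flipped one), count a prediction
            # as correct if it matches either of the two source labels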
            acc = (pred.eq(raw_tar) | pred.eq(raw_tar.flip(0))).float().mean().item() * 100
            tot_acc += acc

        # backward
        optimizer.zero_grad()
        loss = criterion(oup, tar)
        loss.backward()
        loss = loss.item()
        tot_loss += loss
        if args.clip > 0:
            orig_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip).item()
        else:
            orig_norm = None
        optimizer.step()
        model_ema.update(model)
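        # synchronize before reading the clock so the timing-based progress estimate below reflects finished GPU work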
        torch.cuda.synchronize()

        # log
        if args.is_master and cur_it % log_freq == 0:
            tb_lg.add_scalar(f'ft_train/it_loss', loss, cur_it)
            tb_lg.add_scalar(f'ft_train/it_acc', acc, cur_it)
            tb_lg.add_scalar(f'ft_hp/min_lr', min_lr, cur_it), tb_lg.add_scalar(f'ft_hp/max_lr', max_lr, cur_it)
            tb_lg.add_scalar(f'ft_hp/min_wd', min_wd, cur_it), tb_lg.add_scalar(f'ft_hp/max_wd', max_wd, cur_it)
            if orig_norm is not None:
                tb_lg.add_scalar(f'ft_hp/orig_norm', orig_norm, cur_it)

        if it in [3, iters_train//2, iters_train-1]:
            remain_secs = (iters_train-1 - it) * (time.time() - ep_start_time) / (it + 1)
            remain_time = datetime.timedelta(seconds=round(remain_secs))
            print(f'[ep{ep} it{it:3d}/{iters_train}] L: {loss:.4f} Acc: {acc:.2f} lr: {min_lr:.1e}~{max_lr:.1e} Remain: {remain_time}')

    return tot_loss / iters_train, tot_acc / iters_train


@torch.no_grad()
def evaluate(dev, iterator_val, iters_val, model):
    training = model.training
    model.train(False)
    tot_pred, tot_correct = 0., 0.
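    # each rank runs its own iters_val batches; per-rank counts are summed across ranks below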
    for _ in range(iters_val):
        inp, tar = next(iterator_val)
        tot_pred += tar.shape[0]
        inp = inp.to(dev, non_blocking=True)
        tar = tar.to(dev, non_blocking=True)
        oup = model(inp)
        tot_correct += oup.argmax(dim=1).eq(tar).sum().item()
    model.train(training)
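    # all-reduce the (sample count, correct count) pair, then report global top-1 accuracy in percent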
    t = torch.tensor([tot_pred, tot_correct]).to(dev)
    tdist.all_reduce(t)
    return t[0].item(), (t[1] / t[0]).item() * 100.


if __name__ == '__main__':
    main_ft()