# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

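# Fine-tunes a timm-registered backbone (ConvNeXt or ResNet) for image classification:
# distributed data-parallel training with mixup, an EMA copy of the weights, and
# per-iteration lr/wd annealing (lr_decay.lr_wd_annealing), with periodic evaluation
# and checkpointing of the best / best-EMA / last weights.
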
import datetime
import time

import torch
import torch.distributed as tdist
from timm.utils import ModelEmaV2
from torch.utils.tensorboard import SummaryWriter

from arg import get_args, FineTuneArgs
from models import ConvNeXt, ResNet
__for_timm_registration = ConvNeXt, ResNet
from lr_decay import lr_wd_annealing
from util import init_distributed_environ, create_model_opt, load_checkpoint, save_checkpoint
from data import create_classification_dataset


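# End-to-end fine-tuning driver: set up the distributed environment, build the model,
# optimizer and data loaders, resume from a checkpoint if one exists, then train for
# args.ep epochs with periodic evaluation of both the raw and EMA weights.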
def main_ft():
    world_size, global_rank, local_rank, device = init_distributed_environ()
    args: FineTuneArgs = get_args(world_size, global_rank, local_rank, device)
    print(f'initial args:\n{str(args)}')
    args.log_epoch()

    criterion, mixup_fn, model_without_ddp, model, model_ema, optimizer = create_model_opt(args)
    ep_start, performance_desc = load_checkpoint(args.resume_from, model_without_ddp, model_ema, optimizer)

    if ep_start >= args.ep:  # load from a complete checkpoint file
        print(f' [*] [FT already done] Max/Last Acc: {performance_desc}')
    else:
        tb_lg = SummaryWriter(args.tb_lg_dir) if args.is_master else None
        loader_train, iters_train, iterator_val, iters_val = create_classification_dataset(
            args.data_path, args.img_size, args.rep_aug,
            args.dataloader_workers, args.batch_size_per_gpu, args.world_size, args.global_rank
        )

        # train & eval
        tot_pred, last_acc = evaluate(args.device, iterator_val, iters_val, model)
        max_acc = last_acc
        max_acc_e = last_acc_e = evaluate(args.device, iterator_val, iters_val, model_ema.module)[-1]
        print(f'[fine-tune] initial acc={last_acc:.2f}, ema={last_acc_e:.2f}')

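        # evaluate every 5 epochs during the first third of training, then every epoch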
        ep_eval = set(range(0, args.ep//3, 5)) | set(range(args.ep//3, args.ep))
        print(f'[FT start] ep_eval={sorted(ep_eval)} ')
        print(f'[FT start] from ep{ep_start}')

        params_req_grad = [p for p in model.parameters() if p.requires_grad]
        ft_start_time = time.time()
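        # main fine-tuning loop: one call to fine_tune_one_epoch per epoch, with
        # evaluation and checkpointing on the epochs listed in ep_eval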
        for ep in range(ep_start, args.ep):
            ep_start_time = time.time()
            if hasattr(loader_train, 'sampler') and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(ep)
                if 0 <= ep <= 3:
                    print(f'[loader_train.sampler.set_epoch({ep})]')

            train_loss, train_acc = fine_tune_one_epoch(ep, args, tb_lg, loader_train, iters_train, criterion, mixup_fn, model, model_ema, optimizer, params_req_grad)
            if ep in ep_eval:
                eval_start_time = time.time()
                tot_pred, last_acc = evaluate(args.device, iterator_val, iters_val, model)
                tot_pred_e, last_acc_e = evaluate(args.device, iterator_val, iters_val, model_ema.module)
                eval_cost = round(time.time() - eval_start_time, 2)
                performance_desc = f'Max (Last) Acc: {max(max_acc, last_acc):.2f} ({last_acc:.2f} o {tot_pred}) EMA: {max(max_acc_e, last_acc_e):.2f} ({last_acc_e:.2f} o {tot_pred_e})'
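                # always refresh the 'last' checkpoint; additionally save 'best' / 'best_ema'
                # snapshots whenever the raw or EMA validation accuracy improves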
                states = model_without_ddp.state_dict(), model_ema.module.state_dict(), optimizer.state_dict()
                if last_acc > max_acc:
                    max_acc = last_acc
                    save_checkpoint(f'{args.model}_1kfinetuned_best.pth', args, ep, performance_desc, *states)
                if last_acc_e > max_acc_e:
                    max_acc_e = last_acc_e
                    save_checkpoint(f'{args.model}_1kfinetuned_best_ema.pth', args, ep, performance_desc, *states)
                save_checkpoint(f'{args.model}_1kfinetuned_last.pth', args, ep, performance_desc, *states)
            else:
                eval_cost = '-'

            ep_cost = round(time.time() - ep_start_time, 2) + 1  # +1s: approximate the following logging cost
            remain_secs = (args.ep-1 - ep) * ep_cost
            remain_time = datetime.timedelta(seconds=round(remain_secs))
            finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs))
            print(f'[ep{ep}/{args.ep}] {performance_desc} Ep cost: {ep_cost}s, Ev cost: {eval_cost}, Remain: {remain_time}, Finish @ {finish_time}')
            args.cur_ep = f'{ep + 1}/{args.ep}'
            args.remain_time, args.finish_time = str(remain_time), str(finish_time)
            args.train_loss, args.train_acc, args.best_val_acc = train_loss, train_acc, max(max_acc, max_acc_e)
            args.log_epoch()

            if args.is_master:
                tb_lg.add_scalar(f'ft_train/ep_loss', train_loss, ep)
                tb_lg.add_scalar(f'ft_eval/max_acc', max_acc, ep)
                tb_lg.add_scalar(f'ft_eval/last_acc', last_acc, ep)
                tb_lg.add_scalar(f'ft_eval/max_acc_ema', max_acc_e, ep)
                tb_lg.add_scalar(f'ft_eval/last_acc_ema', last_acc_e, ep)
                tb_lg.add_scalar(f'ft_z_burnout/rest_hours', round(remain_secs/60/60, 2), ep)
                tb_lg.flush()

        # finish fine-tuning
        result_acc = max(max_acc, max_acc_e)
        if args.is_master:
            tb_lg.add_scalar('ft_result/result_acc', result_acc, ep_start)
            tb_lg.add_scalar('ft_result/result_acc', result_acc, args.ep)
            tb_lg.flush()
            tb_lg.close()
        print(f'final args:\n{str(args)}')
        print('\n\n')
        print(f' [*] [FT finished] {performance_desc} Total Cost: {(time.time() - ft_start_time) / 60 / 60:.1f}h\n')
        print(f' [*] [FT finished] max(max_acc, max_acc_e)={result_acc} EMA better={max_acc_e>max_acc}')
        print('\n\n')
        time.sleep(10)

    args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time()))
    args.log_epoch()


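# One training epoch: per-iteration lr/wd annealing, optional mixup, gradient clipping,
# an EMA update after every optimizer step, and periodic TensorBoard / console logging.
# Returns the epoch-averaged training loss and accuracy.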
def fine_tune_one_epoch(ep, args: FineTuneArgs, tb_lg: SummaryWriter, loader_train, iters_train, criterion, mixup_fn, model, model_ema: ModelEmaV2, optimizer, params_req_grad):
    model.train()
    tot_loss = tot_acc = 0.0
    log_freq = max(1, round(iters_train * 0.7))
    ep_start_time = time.time()
    for it, (inp, tar) in enumerate(loader_train):
        # adjust lr and wd
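        # (warmup lasts args.wp_ep epochs; the returned min/max lr and wd are only used for logging)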
        cur_it = it + ep * iters_train
        min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer, args.lr, args.wd, cur_it, args.wp_ep * iters_train, args.ep * iters_train)

        # forward
        inp = inp.to(args.device, non_blocking=True)
        raw_tar = tar = tar.to(args.device, non_blocking=True)
        if mixup_fn is not None:
            inp, tar = mixup_fn(inp, tar)
        oup = model(inp)
        pred = oup.data.argmax(dim=1)
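        # training accuracy: with mixup, each sample is blended with the one at the mirrored
        # batch position (raw_tar.flip(0)), so count a hit if the prediction matches either label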
        if mixup_fn is None:
            acc = pred.eq(tar).float().mean().item() * 100
            tot_acc += acc
        else:
            acc = (pred.eq(raw_tar) | pred.eq(raw_tar.flip(0))).float().mean().item() * 100
            tot_acc += acc

        # backward
        optimizer.zero_grad()
        loss = criterion(oup, tar)
        loss.backward()
        loss = loss.item()
        tot_loss += loss
        if args.clip > 0:
            orig_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip).item()
        else:
            orig_norm = None
        optimizer.step()
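        # keep the EMA copy of the weights in sync after every optimizer step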
        model_ema.update(model)
        torch.cuda.synchronize()

        # log
        if args.is_master and cur_it % log_freq == 0:
            tb_lg.add_scalar(f'ft_train/it_loss', loss, cur_it)
            tb_lg.add_scalar(f'ft_train/it_acc', acc, cur_it)
            tb_lg.add_scalar(f'ft_hp/min_lr', min_lr, cur_it), tb_lg.add_scalar(f'ft_hp/max_lr', max_lr, cur_it)
            tb_lg.add_scalar(f'ft_hp/min_wd', min_wd, cur_it), tb_lg.add_scalar(f'ft_hp/max_wd', max_wd, cur_it)
            if orig_norm is not None:
                tb_lg.add_scalar(f'ft_hp/orig_norm', orig_norm, cur_it)

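        # console progress: print near the start, middle, and end of the epoch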
        if it in [3, iters_train//2, iters_train-1]:
            remain_secs = (iters_train-1 - it) * (time.time() - ep_start_time) / (it + 1)
            remain_time = datetime.timedelta(seconds=round(remain_secs))
            print(f'[ep{ep} it{it:3d}/{iters_train}] L: {loss:.4f} Acc: {acc:.2f} lr: {min_lr:.1e}~{max_lr:.1e} Remain: {remain_time}')

    return tot_loss / iters_train, tot_acc / iters_train


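# Distributed top-1 evaluation: each rank scores its own shard of the validation set,
# then the per-rank (sample, correct) counts are summed across ranks with all_reduce.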
@torch.no_grad()
def evaluate(dev, iterator_val, iters_val, model):
    training = model.training
    model.train(False)
    tot_pred, tot_correct = 0., 0.
    for _ in range(iters_val):
        inp, tar = next(iterator_val)
        tot_pred += tar.shape[0]
        inp = inp.to(dev, non_blocking=True)
        tar = tar.to(dev, non_blocking=True)
        oup = model(inp)
        tot_correct += oup.argmax(dim=1).eq(tar).sum().item()
    model.train(training)
    t = torch.tensor([tot_pred, tot_correct]).to(dev)
    tdist.all_reduce(t)
    return t[0].item(), (t[1] / t[0]).item() * 100.


if __name__ == '__main__':
    main_ft()