# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import os
import sys
from functools import partial
from typing import List, Tuple, Callable

import pytz
import torch
import torch.distributed as tdist
import torch.multiprocessing as tmp
from timm import create_model
from timm.data import Mixup
from timm.loss import SoftTargetCrossEntropy, BinaryCrossEntropy
from timm.optim import AdamW, Lamb
from timm.utils import ModelEmaV2
from torch.nn.parallel import DistributedDataParallel
from torch.optim.optimizer import Optimizer

from arg import FineTuneArgs
from lr_decay import get_param_groups


def time_str(for_dirname=False):
    return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('%m-%d_%H-%M-%S' if for_dirname else '[%m-%d %H:%M:%S]')


def init_distributed_environ():
    """Initialize torch.distributed (NCCL), bind this process to one GPU, and patch `print`
    so that only local rank 0 prints (unless called with force=True).
    Returns (world_size, global_rank, local_rank, device)."""
    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if tmp.get_start_method(allow_none=True) is None:
        tmp.set_start_method('spawn')
    # the 'error' default makes int() fail loudly if RANK was not set by the distributed launcher
    global_rank, num_gpus = int(os.environ.get('RANK', 'error')), torch.cuda.device_count()
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)

    tdist.init_process_group(backend='nccl')
    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    # print only when local_rank == 0 or print(..., force=True)
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def prt(msg, *args, **kwargs):
        force = kwargs.pop('force', False)
        if local_rank == 0 or force:
            f_back = sys._getframe().f_back
            file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
            builtin_print(f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=> {msg}', *args, **kwargs)

    __builtin__.print = prt
    tdist.barrier()
    return tdist.get_world_size(), global_rank, local_rank, torch.empty(1).cuda().device


def create_model_opt(args: FineTuneArgs) -> Tuple[torch.nn.Module, Callable, torch.nn.Module, DistributedDataParallel, ModelEmaV2, Optimizer]:
    """Build the fine-tuning model (with optional SyncBN), its EMA copy, the DDP wrapper,
    the optimizer (AdamW or LAMB), the mixup function, and the loss.
    Returns (criterion, mixup_fn, model_without_ddp, ddp_model, model_ema, optimizer)."""
    num_classes = 1000
    model_without_ddp: torch.nn.Module = create_model(args.model, num_classes=num_classes, drop_path_rate=args.drop_path).to(args.device)
    model_para = f'{sum(p.numel() for p in model_without_ddp.parameters() if p.requires_grad) / 1e6:.1f}M'
    # create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
    model_ema = ModelEmaV2(model_without_ddp, decay=args.ema, device=args.device)
    if args.sbn:
        model_without_ddp = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_without_ddp)
    print(f'[model={args.model}] [#para={model_para}, drop_path={args.drop_path}, ema={args.ema}] {model_without_ddp}\n')
    model = DistributedDataParallel(model_without_ddp, device_ids=[args.local_rank], find_unused_parameters=False, broadcast_buffers=False)
    model.train()
    opt_cls = {
        'adam': AdamW, 'adamw': AdamW,
        'lamb': partial(Lamb, max_grad_norm=1e7, always_adapt=True, bias_correction=False),
    }
    param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'cls_token', 'pos_embed', 'mask_token', 'gamma'}, lr_scale=args.lr_scale)
    # param_groups[0] is like this: {'params': List[nn.Parameter], 'lr': float, 'lr_scale': float, 'weight_decay': float, 'weight_decay_scale': float}
    optimizer = opt_cls[args.opt](param_groups, lr=args.lr, weight_decay=0)
    print(f'[optimizer={type(optimizer)}]')
    mixup_fn = Mixup(
        mixup_alpha=args.mixup, cutmix_alpha=1.0, cutmix_minmax=None,
        prob=1.0, switch_prob=0.5, mode='batch',
        label_smoothing=0.1, num_classes=num_classes
    )
    mixup_fn.mixup_enabled = args.mixup > 0.0
    if 'lamb' in args.opt:
        # label smoothing is already applied by the Mixup above (label_smoothing=0.1), so smoothing=0 here
        criterion = BinaryCrossEntropy(smoothing=0, target_threshold=None)
    else:
        criterion = SoftTargetCrossEntropy()
    print(f'[loss_fn] {criterion}')
    print(f'[mixup_fn] {mixup_fn}')
    return criterion, mixup_fn, model_without_ddp, model, model_ema, optimizer


def load_checkpoint(resume_from, model_without_ddp, ema_module, optimizer):
    """Resume fine-tuning from the checkpoint at `resume_from` (must not be a pretraining ckpt);
    restores model, EMA, and optimizer states and returns (start_epoch, performance_desc)."""
    if len(resume_from) == 0 or not os.path.exists(resume_from):
        raise AttributeError(f'ckpt `{resume_from}` not found!')
        # return 0, '[no performance_desc]'
    print(f'[try to resume from file `{resume_from}`]')
    checkpoint = torch.load(resume_from, map_location='cpu')
    assert checkpoint.get('is_pretrain', False) == False, 'please do not use `PT-xxxx-.pth`; it is only for pretraining'

    ep_start, performance_desc = checkpoint.get('epoch', -1) + 1, checkpoint.get('performance_desc', '[no performance_desc]')
    missing, unexpected = model_without_ddp.load_state_dict(checkpoint.get('module', checkpoint), strict=False)
    print(f'[load_checkpoint] missing_keys={missing}')
    print(f'[load_checkpoint] unexpected_keys={unexpected}')
    print(f'[load_checkpoint] ep_start={ep_start}, performance_desc={performance_desc}')

    if 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if 'ema' in checkpoint:
        ema_module.load_state_dict(checkpoint['ema'])
    return ep_start, performance_desc


def save_checkpoint(save_to, args, epoch, performance_desc, model_without_ddp_state, ema_state, optimizer_state):
    """Write a fine-tuning checkpoint to `args.exp_dir/save_to`; only the local master rank saves."""
    checkpoint_path = os.path.join(args.exp_dir, save_to)
    if args.is_local_master:
        to_save = {
            'args': str(args),
            'arch': args.model,
            'epoch': epoch,
            'performance_desc': performance_desc,
            'module': model_without_ddp_state,
            'ema': ema_state,
            'optimizer': optimizer_state,
            'is_pretrain': False,
        }
        torch.save(to_save, checkpoint_path)
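

# Rough usage sketch (illustration only; the real entry point lives in the training script, and
# the field/file names below, e.g. `args.resume_from`, `args.epochs`, 'finetune_checkpoint.pth',
# are assumptions rather than part of this module):
#
#   world_size, global_rank, local_rank, device = init_distributed_environ()
#   args: FineTuneArgs = ...  # parsed elsewhere; must carry device, local_rank, model, lr, opt, etc.
#   criterion, mixup_fn, model_without_ddp, model, model_ema, optimizer = create_model_opt(args)
#   ep_start, performance_desc = load_checkpoint(args.resume_from, model_without_ddp, model_ema.module, optimizer)
#   for epoch in range(ep_start, args.epochs):
#       ...  # one fine-tuning epoch: apply mixup_fn, compute criterion, optimizer.step(), model_ema.update(model)
#       save_checkpoint('finetune_checkpoint.pth', args, epoch, performance_desc,
#                       model_without_ddp.state_dict(), model_ema.module.state_dict(), optimizer.state_dict())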