# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
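#
# Fine-tuning utilities: distributed (NCCL) environment setup, model / optimizer / EMA
# construction, and checkpoint loading & saving.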

import datetime
import os
import sys
from functools import partial
from typing import List, Tuple, Callable

import pytz
import torch
import torch.distributed as tdist
import torch.multiprocessing as tmp
from timm import create_model
from timm.data import Mixup
from timm.loss import SoftTargetCrossEntropy, BinaryCrossEntropy
from timm.optim import AdamW, Lamb
from timm.utils import ModelEmaV2
from torch.nn.parallel import DistributedDataParallel
from torch.optim.optimizer import Optimizer

from arg import FineTuneArgs
from lr_decay import get_param_groups


def time_str(for_dirname=False):
    return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('%m-%d_%H-%M-%S' if for_dirname else '[%m-%d %H:%M:%S]')


def init_distributed_environ():
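    # Pins this process to its GPU (local rank = global rank % visible GPUs), initializes the NCCL
    # process group, and monkey-patches builtins.print so that only local rank 0 prints by default
    # (pass force=True to print from any rank). Returns (world_size, global_rank, local_rank, device).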
    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if tmp.get_start_method(allow_none=True) is None:
        tmp.set_start_method('spawn')
    global_rank, num_gpus = int(os.environ.get('RANK', 'error')), torch.cuda.device_count()
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)

    tdist.init_process_group(backend='nccl')
    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    # print only when local_rank == 0 or print(..., force=True)
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def prt(msg, *args, **kwargs):
        force = kwargs.pop('force', False)
        if local_rank == 0 or force:
            f_back = sys._getframe().f_back
            file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
            builtin_print(f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=> {msg}', *args, **kwargs)

    __builtin__.print = prt
    tdist.barrier()
    return tdist.get_world_size(), global_rank, local_rank, torch.empty(1).cuda().device


def create_model_opt(args: FineTuneArgs) -> Tuple[torch.nn.Module, Callable, torch.nn.Module, DistributedDataParallel, ModelEmaV2, Optimizer]:
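    # Builds the timm model and its EMA copy, wraps the model in DistributedDataParallel,
    # creates the optimizer (AdamW or LAMB) over layer-wise-decayed param groups, and sets up
    # Mixup/CutMix and the loss; returns (criterion, mixup_fn, model_without_ddp, model, model_ema, optimizer).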
    num_classes = 1000
    model_without_ddp: torch.nn.Module = create_model(args.model, num_classes=num_classes, drop_path_rate=args.drop_path).to(args.device)
    model_para = f'{sum(p.numel() for p in model_without_ddp.parameters() if p.requires_grad) / 1e6:.1f}M'
    # create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
    model_ema = ModelEmaV2(model_without_ddp, decay=args.ema, device=args.device)
    if args.sbn:
        model_without_ddp = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_without_ddp)
    print(f'[model={args.model}] [#para={model_para}, drop_path={args.drop_path}, ema={args.ema}] {model_without_ddp}\n')
    model = DistributedDataParallel(model_without_ddp, device_ids=[args.local_rank], find_unused_parameters=False, broadcast_buffers=False)
    model.train()
    opt_cls = {
        'adam': AdamW, 'adamw': AdamW,
        'lamb': partial(Lamb, max_grad_norm=1e7, always_adapt=True, bias_correction=False),
    }
    param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'cls_token', 'pos_embed', 'mask_token', 'gamma'}, lr_scale=args.lr_scale)
    # param_groups[0] is like this: {'params': List[nn.Parameters], 'lr': float, 'lr_scale': float, 'weight_decay': float, 'weight_decay_scale': float}
    optimizer = opt_cls[args.opt](param_groups, lr=args.lr, weight_decay=0)
    print(f'[optimizer={type(optimizer)}]')
    mixup_fn = Mixup(
        mixup_alpha=args.mixup, cutmix_alpha=1.0, cutmix_minmax=None,
        prob=1.0, switch_prob=0.5, mode='batch',
        label_smoothing=0.1, num_classes=num_classes
    )
    mixup_fn.mixup_enabled = args.mixup > 0.0
    if 'lamb' in args.opt:
        # label smoothing is handled in AdaptiveMixup via `label_smoothing`, so here smoothing=0
        criterion = BinaryCrossEntropy(smoothing=0, target_threshold=None)
    else:
        criterion = SoftTargetCrossEntropy()
    print(f'[loss_fn] {criterion}')
    print(f'[mixup_fn] {mixup_fn}')
    return criterion, mixup_fn, model_without_ddp, model, model_ema, optimizer


def load_checkpoint(resume_from, model_without_ddp, ema_module, optimizer):
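    # Resumes from a fine-tuning checkpoint (pretraining checkpoints are rejected): loads model
    # weights non-strictly, plus the optimizer and EMA states when present.
    # Returns (start_epoch, performance_desc).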
    if len(resume_from) == 0 or not os.path.exists(resume_from):
        raise AttributeError(f'ckpt `{resume_from}` not found!')
        # return 0, '[no performance_desc]'
    print(f'[try to resume from file `{resume_from}`]')
    checkpoint = torch.load(resume_from, map_location='cpu')
    assert checkpoint.get('is_pretrain', False) == False, 'please do not use `PT-xxxx-.pth`; it is only for pretraining'

    ep_start, performance_desc = checkpoint.get('epoch', -1) + 1, checkpoint.get('performance_desc', '[no performance_desc]')
    missing, unexpected = model_without_ddp.load_state_dict(checkpoint.get('module', checkpoint), strict=False)
    print(f'[load_checkpoint] missing_keys={missing}')
    print(f'[load_checkpoint] unexpected_keys={unexpected}')
    print(f'[load_checkpoint] ep_start={ep_start}, performance_desc={performance_desc}')

    if 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if 'ema' in checkpoint:
        ema_module.load_state_dict(checkpoint['ema'])
    return ep_start, performance_desc


def save_checkpoint(save_to, args, epoch, performance_desc, model_without_ddp_state, ema_state, optimizer_state):
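    # Only the local master process writes the file, so multiple ranks on one node do not
    # write the same checkpoint concurrently.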
    checkpoint_path = os.path.join(args.exp_dir, save_to)
    if args.is_local_master:
        to_save = {
            'args': str(args),
            'arch': args.model,
            'epoch': epoch,
            'performance_desc': performance_desc,
            'module': model_without_ddp_state,
            'ema': ema_state,
            'optimizer': optimizer_state,
            'is_pretrain': False,
        }
        torch.save(to_save, checkpoint_path)