# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import os
import sys
from functools import partial
from typing import List, Tuple, Callable
import pytz
import torch
import torch.distributed as tdist
import torch.multiprocessing as tmp
from timm import create_model
from timm.loss import SoftTargetCrossEntropy, BinaryCrossEntropy
from timm.optim import AdamW, Lamb
from timm.utils import ModelEmaV2
from torch.nn.parallel import DistributedDataParallel
from torch.optim.optimizer import Optimizer

from arg import FineTuneArgs
from downstream_imagenet.mixup import BatchMixup
from lr_decay import get_param_groups
def time_str(for_dirname=False):
    return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('%m-%d_%H-%M-%S' if for_dirname else '[%m-%d %H:%M:%S]')
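# Example (a sketch of the two formats produced by the strftime patterns above):
#   time_str()                  # -> '[05-20 13:05:59]'  (prefix for log lines)
#   time_str(for_dirname=True)  # -> '05-20_13-05-59'    (safe to embed in a directory name)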
def init_distributed_environ():
    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if tmp.get_start_method(allow_none=True) is None:
        tmp.set_start_method('spawn')
    global_rank, num_gpus = int(os.environ.get('RANK', 'error')), torch.cuda.device_count()
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)
    tdist.init_process_group(backend='nccl')
    assert tdist.is_initialized(), 'torch.distributed is not initialized!'

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    # print only when local_rank == 0 or print(..., force=True)
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def prt(msg, *args, **kwargs):
        force = kwargs.pop('force', False)
        if local_rank == 0 or force:
            f_back = sys._getframe().f_back
            file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
            builtin_print(f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=> {msg}', *args, **kwargs)
    __builtin__.print = prt

    tdist.barrier()
    return tdist.get_world_size(), global_rank, local_rank, torch.empty(1).cuda().device
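# Example usage (a minimal sketch): assumes the process group is started by a distributed
# launcher such as `torchrun --nproc_per_node=<num_gpus> <entry_script>.py` (the entry-script
# name is illustrative), which sets the `RANK` environment variable read above, one process per GPU:
#   world_size, global_rank, local_rank, device = init_distributed_environ()
#   print(f'[rank {global_rank}/{world_size}] using {device}', force=True)   # force=True prints on every rank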
def create_model_opt(args: FineTuneArgs) -> Tuple[torch.nn.Module, Callable, torch.nn.Module, DistributedDataParallel, ModelEmaV2, Optimizer]:
    num_classes = 1000
    model_without_ddp: torch.nn.Module = create_model(args.model, num_classes=num_classes, drop_path_rate=args.drop_path).to(args.device)
    model_para = f'{sum(p.numel() for p in model_without_ddp.parameters() if p.requires_grad) / 1e6:.1f}M'
    # create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
    model_ema = ModelEmaV2(model_without_ddp, decay=args.ema, device=args.device)
    if args.sbn:
        model_without_ddp = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_without_ddp)
    print(f'[model={args.model}] [#para={model_para}, drop_path={args.drop_path}, ema={args.ema}] {model_without_ddp}\n')
    model = DistributedDataParallel(model_without_ddp, device_ids=[args.local_rank], find_unused_parameters=False, broadcast_buffers=False)
    model.train()

    opt_cls = {
        'adam': AdamW, 'adamw': AdamW,
        'lamb': partial(Lamb, max_grad_norm=1e7, always_adapt=True, bias_correction=False),
    }
    param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'cls_token', 'pos_embed', 'mask_token', 'gamma'}, lr_scale=args.lr_scale)
    # param_groups[0] is like: {'params': List[nn.Parameter], 'lr': float, 'lr_scale': float, 'weight_decay': float, 'weight_decay_scale': float}
    optimizer = opt_cls[args.opt](param_groups, lr=args.lr, weight_decay=0)
    print(f'[optimizer={type(optimizer)}]')

    mixup_fn = BatchMixup(
        mixup_alpha=args.mixup, cutmix_alpha=1.0, cutmix_minmax=None,
        prob=1.0, switch_prob=0.5, mode='batch',
        label_smoothing=0.1, num_classes=num_classes
    )
    mixup_fn.mixup_enabled = args.mixup > 0.0
    if 'lamb' in args.opt:
        # label smoothing is already applied by BatchMixup via `label_smoothing`, so smoothing=0 here
        criterion = BinaryCrossEntropy(smoothing=0, target_threshold=None)
    else:
        criterion = SoftTargetCrossEntropy()
    print(f'[loss_fn] {criterion}')
    print(f'[mixup_fn] {mixup_fn}')
    return criterion, mixup_fn, model_without_ddp, model, model_ema, optimizer
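# Example of one fine-tuning step with the objects returned by `create_model_opt` (a minimal
# sketch, not this repo's training loop; it assumes `BatchMixup` follows timm's `Mixup` calling
# convention and that `inp`/`tar` come from an ImageNet dataloader):
#   criterion, mixup_fn, model_without_ddp, model, model_ema, optimizer = create_model_opt(args)
#   inp, tar = mixup_fn(inp, tar)              # mixup/cutmix + label smoothing on the whole batch
#   loss = criterion(model(inp), tar)
#   loss.backward()
#   optimizer.step(); optimizer.zero_grad()
#   model_ema.update(model_without_ddp)        # keep the EMA copy in sync after each update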
def load_checkpoint(resume_from, model_without_ddp, ema_module, optimizer):
    if len(resume_from) == 0 or not os.path.exists(resume_from):
        raise AttributeError(f'ckpt `{resume_from}` not found!')
        # return 0, '[no performance_desc]'
    print(f'[try to resume from file `{resume_from}`]')
    checkpoint = torch.load(resume_from, map_location='cpu')
    assert checkpoint.get('is_pretrain', False) == False, 'Please do not use `*_withdecoder_1kpretrained_spark_style.pth`, which is ONLY for resuming the pretraining. Use `*_1kpretrained_timm_style.pth` or `*_1kfinetuned*.pth` instead.'
    ep_start, performance_desc = checkpoint.get('epoch', -1) + 1, checkpoint.get('performance_desc', '[no performance_desc]')
    missing, unexpected = model_without_ddp.load_state_dict(checkpoint.get('module', checkpoint), strict=False)
    print(f'[load_checkpoint] missing_keys={missing}')
    print(f'[load_checkpoint] unexpected_keys={unexpected}')
    print(f'[load_checkpoint] ep_start={ep_start}, performance_desc={performance_desc}')
    if 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if 'ema' in checkpoint:
        ema_module.load_state_dict(checkpoint['ema'])
    return ep_start, performance_desc
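# Example resume call (a sketch; passing `model_ema.module` as `ema_module` and the name of the
# `resume_from` field on `args` are assumptions, not fixed by this file):
#   ep_start, performance_desc = load_checkpoint(args.resume_from, model_without_ddp, model_ema.module, optimizer)
#   # `ep_start` comes out as 0 for a timm-style pretrained checkpoint that has no 'epoch' entry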
def save_checkpoint(save_to, args, epoch, performance_desc, model_without_ddp_state, ema_state, optimizer_state):
    checkpoint_path = os.path.join(args.exp_dir, save_to)
    if args.is_local_master:
        to_save = {
            'args': str(args),
            'arch': args.model,
            'epoch': epoch,
            'performance_desc': performance_desc,
            'module': model_without_ddp_state,
            'ema': ema_state,
            'optimizer': optimizer_state,
            'is_pretrain': False,
        }
        torch.save(to_save, checkpoint_path)
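# Example save call at the end of an epoch (a sketch; the file name is illustrative, and only the
# local master rank actually writes the file; the keys mirror what `load_checkpoint` reads back):
#   save_checkpoint(f'{args.model}_1kfinetuned_last.pth', args, epoch, performance_desc,
#                   model_without_ddp.state_dict(), model_ema.module.state_dict(), optimizer.state_dict())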