You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
105 lines
3.4 KiB
105 lines
3.4 KiB
2 years ago
|
# Copyright (c) ByteDance, Inc. and its affiliates.
|
||
|
# All rights reserved.
|
||
|
#
|
||
|
# This source code is licensed under the license found in the
|
||
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
||
|
import math
|
||
|
|
||
|
import torch
|
||
|
from timm.data import Mixup
|
||
|
from timm.loss import BinaryCrossEntropy, SoftTargetCrossEntropy
|
||
|
from timm.models.layers import drop
|
||
|
from timm.models.resnet import ResNet
|
||
|
|
||
|
from .convnext_official import ConvNeXt
|
||
|
|
||
|
|
||
|
def convnext_get_layer_id_and_scale_exp(self: ConvNeXt, para_name: str):
    """Map a ConvNeXt parameter name to a layer id and its complementary exponent.

    The total number of backbone layer ids ``N`` is 12 when the second-to-last
    stage holds more than 9 blocks, otherwise 6.

    Args:
        self: model exposing ``stages`` (indexable, each stage supporting ``len``).
        para_name: dotted parameter name, e.g. ``"stages.2.5.dwconv.weight"``.

    Returns:
        ``(layer_id, N + 1 - layer_id)``.  The second value is what the function
        name calls the "scale exp" (presumably the exponent used for layer-wise
        lr decay -- confirm against the caller).
    """
    num_layers = 12 if len(self.stages[-2]) > 9 else 6
    pieces = para_name.split('.')

    if para_name.startswith("downsample_layers"):
        stage_idx = int(pieces[1])
        # downsample 0 -> 0, downsample 1/2 -> 2/3, anything else (stage 3) -> N
        depth = {0: 0, 1: 2, 2: 3}.get(stage_idx, num_layers)
    elif para_name.startswith("stages"):
        # Both indices are parsed up front, whichever stage is hit.
        stage_idx = int(pieces[1])
        block_idx = int(pieces[2])
        if stage_idx == 2:
            # every 3 consecutive blocks of stage 2 share one layer id
            depth = 3 + block_idx // 3
        else:
            # stage 0/1 -> 1/2, anything else (stage 3) -> N
            depth = {0: 1, 1: 2}.get(stage_idx, num_layers)
    else:
        # parameters that live after the backbone
        depth = num_layers + 1

    return depth, num_layers + 1 - depth
|
||
|
|
||
|
|
||
|
# (len(layer2), len(layer3)) -> (blk2, blk3): how many consecutive blocks of
# layer2 / layer3 share a single layer id.  Stage depths per model:
#   50     : [3, 4, 6, 3]
#   101    : [3, 4, 23, 3]
#   152    : [3, 8, 36, 3]
#   200    : [3, 24, 36, 3]
#   eca269d: [3, 30, 48, 8]
_RESNET_BLOCK_GROUP_SIZES = {
    (4, 6): (2, 3),
    (4, 23): (2, 3),
    (8, 36): (4, 4),
    (24, 36): (4, 4),
    (30, 48): (5, 6),
}


def resnets_get_layer_id_and_scale_exp(self: ResNet, para_name: str):
    """Map a ResNet parameter name to a layer id and its complementary exponent.

    Blocks of ``layer2`` and ``layer3`` are grouped (``blk2`` / ``blk3`` blocks
    per group); each group gets its own layer id.  With ``N2`` / ``N3`` groups
    the last backbone id is ``N = 2 + N2 + N3``:

    * anything that is neither ``layer*`` nor ``fc.*`` -> 0
    * ``layer1``  -> 1
    * ``layer2``  -> ``2 + block_id // blk2``
    * ``layer3``  -> ``2 + N2 + block_id // blk3``
    * any other ``layer*`` (i.e. ``layer4``) -> ``N``
    * ``fc.*``    -> ``N + 1``

    Args:
        self: model exposing ``layer2`` and ``layer3`` (both supporting ``len``).
        para_name: dotted parameter name, e.g. ``"layer3.5.bn1.weight"``.

    Returns:
        ``(layer_id, N + 1 - layer_id)``; e.g. r50 ids span 0..7, r101 ids 0..13.

    Raises:
        NotImplementedError: if ``(len(layer2), len(layer3))`` is not one of the
            supported configurations.
    """
    len2, len3 = len(self.layer2), len(self.layer3)
    try:
        blk2, blk3 = _RESNET_BLOCK_GROUP_SIZES[(len2, len3)]
    except KeyError:
        raise NotImplementedError(
            f'unsupported ResNet depth: len(layer2)={len2}, len(layer3)={len3}'
        ) from None

    # Exact integer ceiling division (replaces the float `ceil(x / y - 1e-5)`).
    n2, n3 = -(-len2 // blk2), -(-len3 // blk3)
    last_id = 2 + n2 + n3

    if para_name.startswith('layer'):
        head, block_str = para_name.split('.')[:2]
        stage_id, block_id = int(head[5:]), int(block_str)
        if stage_id == 1:
            layer_id = 1
        elif stage_id == 2:
            layer_id = 2 + block_id // blk2             # r50: 2, 3
        elif stage_id == 3:
            layer_id = 2 + n2 + block_id // blk3        # r50: 4, 5    r101: 4, 5, ..., 11
        else:   # == 4
            layer_id = last_id                          # r50: 6       r101: 12
    elif para_name.startswith('fc.'):
        layer_id = last_id + 1                          # r50: 7       r101: 13
    else:
        layer_id = 0

    return layer_id, last_id + 1 - layer_id             # r50: 0-7, 7-0    r101: 0-13, 13-0
|
||
|
|
||
|
|
||
|
def _ex_repr(self):
    """Render the public, non-module, non-tensor attributes of ``self`` as ``k=v, ...``.

    Attributes whose name starts with ``_`` or equals ``training`` are skipped,
    as are values that are ``torch.nn.Module`` or ``torch.Tensor`` instances.
    Floats are formatted with the ``g`` presentation type; everything else
    goes through ``str``.
    """
    skipped_types = (torch.nn.Module, torch.Tensor)
    rendered = []
    for name, value in vars(self).items():
        if name.startswith('_') or name == 'training':
            continue
        if isinstance(value, skipped_types):
            continue
        if isinstance(value, float):
            text = format(value, 'g')
        else:
            text = str(value)
        rendered.append(f'{name}=' + text)
    return ', '.join(rendered)
|
||
|
|
||
|
|
||
|
# IMPORTANT: import-time monkey-patching of some member functions.
#   * the listed loss / mixup / drop-path classes get `_ex_repr` as their
#     `extra_repr` when they have one, otherwise a `__repr__` built on it;
#   * ResNet and ConvNeXt gain a `get_layer_id_and_scale_exp` method.
# NOTE: `__UPDATED` is assigned False right before it is tested, so the branch
# below runs every time this module body is executed.
__UPDATED = False
if not __UPDATED:
    for clz in (
        torch.nn.CrossEntropyLoss,
        SoftTargetCrossEntropy,
        BinaryCrossEntropy,
        Mixup,
        drop.DropPath,
    ):
        if not hasattr(clz, 'extra_repr'):
            # no `extra_repr` hook on this class: replace `__repr__` wholesale
            clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
            continue
        clz.extra_repr = _ex_repr

    ConvNeXt.get_layer_id_and_scale_exp = convnext_get_layer_id_and_scale_exp
    ResNet.get_layer_id_and_scale_exp = resnets_get_layer_id_and_scale_exp
    __UPDATED = True
|