You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
105 lines
3.1 KiB
105 lines
3.1 KiB
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
|
|
|
import math |
|
|
|
import torch.nn.functional as F |
|
from timm.models.resnet import ResNet |
|
|
|
|
|
def forward_features(self, x, pyramid: int):
    """Run the stem and the four residual stages, optionally returning a pyramid.

    Args:
        x: input passed to ``self.conv1``.
        pyramid: how many of the last stage outputs to return (0, 1, 2, 3 or 4).
            ``0`` disables the pyramid. Values above 4 behave like 4.

    Returns:
        If ``pyramid`` is 0, the output of ``self.layer4``. Otherwise a list
        with one slot per stage in ``layer1`` .. ``layer4`` order: the first
        ``4 - pyramid`` slots are ``None`` and the last ``pyramid`` slots hold
        the outputs of the corresponding stages.

    Raises:
        ValueError: if ``pyramid`` is negative.
    """
    # Fail early with a clear message; a negative value cannot select any
    # stage and would otherwise only surface as an indexing error.
    if pyramid and pyramid < 0:
        raise ValueError(f'pyramid must be a non-negative integer, got {pyramid!r}')

    x = self.conv1(x)
    x = self.bn1(x)
    x = self.act1(x)
    x = self.maxpool(x)

    stages = (self.layer1, self.layer2, self.layer3, self.layer4)
    if not pyramid:
        for stage in stages:
            x = stage(x)
        return x

    # Stages before this index are outside the requested pyramid. Their
    # outputs are never stored, so this function does not keep them alive
    # until the end of the forward pass.
    first_kept = max(0, len(stages) - pyramid)
    features = []
    for index, stage in enumerate(stages):
        x = stage(x)
        features.append(x if index >= first_kept else None)
    return features
|
|
|
|
|
def forward(self, x, pyramid=0):
    """Classify ``x`` or, when ``pyramid`` is non-zero, return stage features.

    Args:
        x: input forwarded to ``self.forward_features``.
        pyramid: passed through to ``self.forward_features``. With the default
            ``0`` the pooled features go through dropout (if enabled) and
            ``self.fc``.

    Returns:
        The output of ``self.fc`` when ``pyramid == 0``; otherwise whatever
        ``self.forward_features(x, pyramid=pyramid)`` returns.
    """
    # Pyramid mode bypasses pooling, dropout and the classifier head.
    if pyramid != 0:
        return self.forward_features(x, pyramid=pyramid)

    pooled = self.global_pool(self.forward_features(x, pyramid=pyramid))
    if self.drop_rate:
        pooled = F.dropout(pooled, p=float(self.drop_rate), training=self.training)
    return self.fc(pooled)
|
|
|
|
|
def resnets_get_layer_id_and_scale_exp(self, para_name: str):
    """Map a parameter name to ``(layer_id, scale_exp)``.

    Blocks of stage 2 and stage 3 are grouped; each group gets its own layer
    id. The group size depends on the depths of ``self.layer2`` and
    ``self.layer3`` (see the table below).

    Layer ids:
        * ``0``      - any name that starts with neither ``'layer'`` nor ``'fc.'``
        * ``1``      - every parameter of ``layer1``
        * ``2 ..``   - groups of ``layer2`` followed by groups of ``layer3``
        * ``N``      - every parameter of the last stage
        * ``N + 1``  - parameters whose name starts with ``'fc.'``

    Args:
        para_name: dotted parameter name, e.g. ``'layer3.4.conv1.weight'``.

    Returns:
        A tuple ``(layer_id, N + 1 - layer_id)``. For ResNet-50 the ids run
        0..7 and the second element runs 7..0; for ResNet-101 they run 0..13
        and 13..0.
        NOTE(review): the second element is called ``scale_exp`` only by the
        function name; presumably it is used as an exponent for layer-wise
        scaling - confirm with the caller.

    Raises:
        NotImplementedError: if the (layer2, layer3) depths are not in the table.
    """
    # (len(layer2), len(layer3)) -> (blocks per group in stage 2, in stage 3)
    group_sizes = {
        (4, 6): (2, 3),      # 50     : stages [3, 4, 6, 3]
        (4, 23): (2, 3),     # 101    : stages [3, 4, 23, 3]
        (8, 36): (4, 4),     # 152    : stages [3, 8, 36, 3]
        (24, 36): (4, 4),    # 200    : stages [3, 24, 36, 3]
        (30, 48): (5, 6),    # eca269d: stages [3, 30, 48, 8]
    }
    depth2, depth3 = len(self.layer2), len(self.layer3)
    sizes = group_sizes.get((depth2, depth3))
    if sizes is None:
        raise NotImplementedError
    per_group2, per_group3 = sizes

    # Number of groups per stage; the small epsilon is kept from the original
    # formula before rounding up.
    groups2 = math.ceil(depth2 / per_group2 - 1e-5)
    groups3 = math.ceil(depth3 / per_group3 - 1e-5)
    last_stage_id = 2 + groups2 + groups3   # r50: 6    r101: 12
    head_id = last_stage_id + 1             # r50: 7    r101: 13

    if para_name.startswith('layer'):
        name_parts = para_name.split('.')
        stage_id = int(name_parts[0][5:])   # digits after the 'layer' prefix
        block_id = int(name_parts[1])
        if stage_id == 1:
            layer_id = 1
        elif stage_id == 2:
            layer_id = 2 + block_id // per_group2
        elif stage_id == 3:
            layer_id = 2 + groups2 + block_id // per_group3
        else:   # last stage
            layer_id = last_stage_id
    elif para_name.startswith('fc.'):
        layer_id = head_id
    else:
        layer_id = 0

    return layer_id, head_id - layer_id
|
|
|
|
|
# Monkey-patch timm's ResNet class: replace its forward pass with the
# pyramid-aware versions defined in this file and attach the layer-id helper.
ResNet.forward = forward
ResNet.forward_features = forward_features
ResNet.get_layer_id_and_scale_exp = resnets_get_layer_id_and_scale_exp
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: build a ResNet-50 through timm and run the patched
    # forward pass once per supported ``pyramid`` value on random input.
    import torch
    from timm.models import create_model

    model = create_model('resnet50')
    with torch.no_grad():
        # Default pyramid=0: print only the shape of the classifier output.
        print(model(torch.rand(2, 3, 224, 224)).shape)
        # pyramid=1..4: print the returned list (None placeholders + features).
        for depth in range(1, 5):
            print(model(torch.rand(2, 3, 224, 224), pyramid=depth))
|
|
|