You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

59 lines
1.5 KiB

2 years ago
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
2 years ago
import torch.nn.functional as F
from timm.models.resnet import ResNet
def forward(self, x, hierarchy=0):
    """Run the ResNet, optionally returning intermediate stage outputs.

    Modified version of ``timm.models.resnet.ResNet.forward``, intended to be
    assigned onto ``ResNet`` so that timm ResNets accept ``hierarchy``.

    Args:
        x: input batch fed to ``self.conv1`` (presumably an image tensor of
            shape (B, C, H, W) -- TODO confirm against callers).
        hierarchy: number of residual stages to expose, counted from the
            deepest one.
            - ``0`` runs the pooling/dropout/``fc`` head and returns its output.
            - ``k`` in 1..4 returns the 4-element list
              ``[layer1_out, layer2_out, layer3_out, layer4_out]``. In that
              list only the last ``k`` entries are tensors; the rest are ``None``.
            - Values above 4 behave like 4.

    Returns:
        ``self.fc(...)`` output when ``hierarchy == 0``, otherwise the list
        described above.

    Raises:
        ValueError: if ``hierarchy`` is negative.
    """
    if hierarchy < 0:
        # A negative value is truthy: without this check the whole backbone
        # would run and a list of four ``None``s would be returned silently.
        raise ValueError(f'hierarchy must be a non-negative int (0-4), got {hierarchy!r}')

    # Stem, in the same order as timm's ResNet.
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.act1(x)
    x = self.maxpool(x)

    # Residual stages, shallowest first. Stage index i (0-based) is kept only
    # when it is among the deepest `hierarchy` stages, i.e. hierarchy >= 4 - i.
    stages = (self.layer1, self.layer2, self.layer3, self.layer4)
    feats = []
    for i, stage in enumerate(stages):
        x = stage(x)
        feats.append(x if hierarchy >= len(stages) - i else None)

    if hierarchy:
        return feats

    # Classification head, identical to timm's ResNet.forward.
    x = self.global_pool(x)
    if self.drop_rate:
        x = F.dropout(x, p=float(self.drop_rate), training=self.training)
    return self.fc(x)
ResNet.forward = forward


if __name__ == '__main__':
    from timm.models import create_model

    r50 = create_model('resnet50')

    def prt(out):
        """Print the shape(s) of a forward output (a tensor, or a list of optional tensors)."""
        if isinstance(out, torch.Tensor):
            # hierarchy == 0 returns the head's output tensor, not a list; iterating it
            # would walk the batch dimension and print per-sample shapes instead.
            print(tuple(out.shape))
        else:
            print([tuple(t.shape) if t is not None else '(None)' for t in out])

    with torch.no_grad():
        inp = torch.rand(2, 3, 224, 224)
        for hierarchy in range(5):
            prt(r50(inp, hierarchy=hierarchy))