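# NOTE(review): the following lines appear to be residue from a web source-viewer page
# and are not part of the module; kept as comments so the file remains valid Python.
# "You can not select more than 25 topics. Topics must start with a letter or number,
# can include dashes ('-') and can be up to 35 characters long."
# 84 lines
# 2.5 KiB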

# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
import torch
import torch.nn.functional as F
from timm.models.resnet import ResNet
# hack: inject the `get_downsample_ratio` function into `timm.models.resnet.ResNet`
def get_downsample_ratio(self: ResNet) -> int:
    """Return the spatial downsampling factor of the deepest feature map.

    The value is read from the `reduction` field of the last entry of
    `self.feature_info` (maintained by `timm`, the same structure used by
    `get_feature_map_channels`). It therefore stays correct when the model was
    built with a non-default output stride. If `feature_info` is missing or
    empty, the previously hard-coded value 32 is returned.

    :return: ratio between the input height/width and the height/width of the
        last feature map returned by `forward(x, hierarchical=True)`.
    """
    feature_info = getattr(self, 'feature_info', None)
    if not feature_info:
        # keep the historical default when timm metadata is unavailable
        return 32
    return int(feature_info[-1]['reduction'])
# hack: inject the `get_feature_map_channels` function into `timm.models.resnet.ResNet`
def get_feature_map_channels(self: ResNet) -> List[int]:
    """Return the channel count of each feature map reported by timm, minus the first.

    Each value is the `num_chs` field of an entry in `self.feature_info`
    (a list maintained by `timm`). The first entry is skipped; presumably it
    describes the stem, since `forward(x, hierarchical=True)` returns only the
    outputs of `layer1` ... `layer4`.
    """
    channels: List[int] = []
    for stage_info in self.feature_info[1:]:
        channels.append(stage_info['num_chs'])
    return channels
# hack: override the forward function of `timm.models.resnet.ResNet`
def forward(self, x, hierarchical=False):
    """Modified version of `timm.models.resnet.ResNet.forward`.

    Runs the stem (`conv1` -> `bn1` -> `act1` -> `maxpool`) and then the four
    residual stages `layer1` ... `layer4`.

    :param x: input batch, passed to `self.conv1`.
    :param hierarchical: if True, return the outputs of all four stages instead
        of running the classification head.
    :return: when `hierarchical` is True, a list of the four stage outputs
        (`layer1` first, `layer4` last); otherwise the result of
        `global_pool` -> optional dropout -> `fc` applied to the `layer4` output.
    """
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.act1(x)
    x = self.maxpool(x)

    # The residual stages must run in both modes: the classification head
    # consumes the output of `layer4`, not the stem output.
    ls = []
    for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
        x = stage(x)
        if hierarchical:
            ls.append(x)
    if hierarchical:
        return ls

    x = self.global_pool(x)
    if self.drop_rate:
        x = F.dropout(x, p=float(self.drop_rate), training=self.training)
    x = self.fc(x)
    return x
# Attach the functions above to timm's ResNet class. Because they are set as
# class attributes, every ResNet instance picks them up.
for _attr_name, _impl in (
    ('get_downsample_ratio', get_downsample_ratio),
    ('get_feature_map_channels', get_feature_map_channels),
    ('forward', forward),
):
    setattr(ResNet, _attr_name, _impl)
del _attr_name, _impl
@torch.no_grad()
def convnet_test():
    """Smoke-test the injected ResNet methods on a `resnet50` from `timm.create_model`.

    Checks three things:
    - `forward(x, hierarchical=True)` returns a list with one 4-D feature map
      per entry of `get_feature_map_channels()`;
    - the channel counts match;
    - the last map is downsampled by `get_downsample_ratio()`.
    """
    from timm.models import create_model
    cnn = create_model('resnet50')

    # query each injected helper once and reuse the result
    downsample_ratio = cnn.get_downsample_ratio()
    feature_map_channels = cnn.get_feature_map_channels()
    print('get_downsample_ratio:', downsample_ratio)
    print('get_feature_map_channels:', feature_map_channels)

    # check the forward function (one pass serves every check below)
    B, C, H, W = 4, 3, 224, 224
    inp = torch.rand(B, C, H, W)
    feats = cnn(inp, hierarchical=True)
    assert isinstance(feats, list)
    assert len(feats) == len(feature_map_channels)
    print([tuple(t.shape) for t in feats])

    # check the downsample ratio
    assert feats[-1].shape[-2] == H // downsample_ratio
    assert feats[-1].shape[-1] == W // downsample_ratio

    # check the channel number
    for feat, ch in zip(feats, feature_map_channels):
        assert feat.ndim == 4
        assert feat.shape[1] == ch
# Running this file directly executes the smoke test `convnet_test` (requires `timm`).
if __name__ == '__main__':
    convnet_test()