You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
84 lines
2.5 KiB
84 lines
2.5 KiB
# Copyright (c) ByteDance, Inc. and its affiliates. |
|
# All rights reserved. |
|
# |
|
# This source code is licensed under the license found in the |
|
# LICENSE file in the root directory of this source tree. |
|
from typing import List |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from timm.models.resnet import ResNet |
|
|
|
|
|
# hack: inject the `get_downsample_ratio` function into `timm.models.resnet.ResNet` |
|
def get_downsample_ratio(self: ResNet) -> int:
    """Return the overall spatial downsample ratio of the last feature map.

    The ratio is read from the `reduction` field of the last entry of
    `self.feature_info` (maintained by `timm`), so models built with a
    smaller output stride report their real ratio instead of a fixed number.

    Returns:
        The input-to-last-stage stride. Falls back to 32, the value this
        function always returned before, when no `reduction` is recorded.
    """
    default_ratio = 32
    feature_info = getattr(self, 'feature_info', None)
    if not feature_info:
        return default_ratio
    return int(feature_info[-1].get('reduction', default_ratio))
|
|
|
|
|
# hack: inject the `get_feature_map_channels` function into `timm.models.resnet.ResNet` |
|
def get_feature_map_channels(self: ResNet) -> List[int]:
    """Collect the `num_chs` value of every `self.feature_info` entry but the first.

    `self.feature_info` is maintained by `timm`. The result is ordered like
    `self.feature_info` itself, with the leading entry dropped.
    """
    channels = []
    for stage_info in self.feature_info[1:]:
        channels.append(stage_info['num_chs'])
    return channels
|
|
|
|
|
# hack: override the forward function of `timm.models.resnet.ResNet` |
|
def forward(self, x, hierarchical=False):
    """Modified version of `timm.models.resnet.ResNet.forward`.

    Args:
        x: input batch, passed through the stem
            (`conv1` -> `bn1` -> `act1` -> `maxpool`) and then the four stages.
        hierarchical: when True, return the outputs of `layer1`..`layer4`
            instead of the classifier output.

    Returns:
        If `hierarchical` is True, a list with the outputs of `layer1`,
        `layer2`, `layer3` and `layer4`, in that order. Otherwise the result
        of `global_pool` -> optional dropout -> `fc` applied to the `layer4`
        output.
    """
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.act1(x)
    x = self.maxpool(x)

    # The stages must run in both modes: the classification head consumes the
    # `layer4` output. Skipping them would feed the stem output to the head.
    feature_maps = []
    for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
        x = stage(x)
        if hierarchical:
            # Only keep references to intermediate maps when they are returned.
            feature_maps.append(x)
    if hierarchical:
        return feature_maps

    x = self.global_pool(x)
    if self.drop_rate:
        x = F.dropout(x, p=float(self.drop_rate), training=self.training)
    return self.fc(x)
|
|
|
|
|
# Install the three functions defined above on `timm`'s ResNet class, each
# under its own function name; `forward` thereby replaces the stock method.
for _patch in (get_downsample_ratio, get_feature_map_channels, forward):
    setattr(ResNet, _patch.__name__, _patch)
del _patch  # do not leave the loop variable in the module namespace
|
|
|
|
|
@torch.no_grad()
def convnet_test():
    """Smoke-test the patched ResNet on a random batch.

    Builds a `resnet50` through `timm`, then checks that the hierarchical
    forward pass returns a list of 4-D feature maps whose count and channel
    sizes match `get_feature_map_channels()`, and that the last map's spatial
    size equals the input size divided by `get_downsample_ratio()`.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    from timm.models import create_model
    cnn = create_model('resnet50')

    # Query each helper once and reuse the value for both printing and checks.
    downsample_ratio = cnn.get_downsample_ratio()
    feature_map_channels = cnn.get_feature_map_channels()
    print('get_downsample_ratio:', downsample_ratio)
    print('get_feature_map_channels:', feature_map_channels)

    # check the forward function
    B, C, H, W = 4, 3, 224, 224
    inp = torch.rand(B, C, H, W)
    feats = cnn(inp, hierarchical=True)
    assert isinstance(feats, list), f'expected a list, got {type(feats).__name__}'
    assert len(feats) == len(feature_map_channels), \
        f'{len(feats)} feature maps vs {len(feature_map_channels)} channel entries'
    print([tuple(t.shape) for t in feats])

    # check the downsample ratio (the feature maps computed above are reused;
    # running the same forward pass a second time adds nothing to the check)
    assert feats[-1].shape[-2] == H // downsample_ratio, \
        f'last map height {feats[-1].shape[-2]} != {H} // {downsample_ratio}'
    assert feats[-1].shape[-1] == W // downsample_ratio, \
        f'last map width {feats[-1].shape[-1]} != {W} // {downsample_ratio}'

    # check the channel number
    for feat, ch in zip(feats, feature_map_channels):
        assert feat.ndim == 4, f'expected a 4-D feature map, got {feat.ndim}-D'
        assert feat.shape[1] == ch, f'expected {ch} channels, got {feat.shape[1]}'
|
|
|
|
|
# Run the self-check only when this file is executed as a script;
# importing the module just applies the ResNet patches.
if __name__ == '__main__':
    convnet_test()
|
|
|