# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from typing import List

from timm.models.registry import register_model


class YourConvNet(nn.Module):
    """
    This is a template for your custom ConvNet.
    You are required to implement the following three methods: `get_downsample_ratio`, `get_feature_map_channels`, `forward`.
    You can refer to the implementations in `pretrain/models/resnet.py` for an example.
    """
    
    def get_downsample_ratio(self) -> int:
        """
        This function is ONLY used in `SparseEncoder.__init__` (see `pretrain/encoder.py`).
        
        :return: the TOTAL downsample ratio of the ConvNet.
            E.g., for a ResNet-50, this should return 32.
        """
        raise NotImplementedError
    
    def get_feature_map_channels(self) -> List[int]:
        """
        This function is ONLY used in `SparseEncoder.__init__` (see `pretrain/encoder.py`).
        
        :return: a list of the number of channels of each feature map.
            E.g., for a ResNet-50, this should return [256, 512, 1024, 2048].
        """
        raise NotImplementedError
    
    def forward(self, inp_bchw: torch.Tensor, hierarchical=False):
        """
        The forward pass with `hierarchical=True` is ONLY used in `SparseEncoder.forward` (see `pretrain/encoder.py`).
        
        :param inp_bchw: input image tensor, shape: (batch_size, channels, height, width).
        :param hierarchical: if False, return the classification logits; if True, return the hierarchical feature maps.
        :return:
            - hierarchical == False: the logits of the classification task, shape: (batch_size, num_classes).
            - hierarchical == True: a list of all feature maps, which should have the same length as the return value of `get_feature_map_channels`.
              E.g., for a ResNet-50, it should return a list [1st_feat_map, 2nd_feat_map, 3rd_feat_map, 4th_feat_map];
              for an input size of 224, their shapes would be [(B, 256, 56, 56), (B, 512, 28, 28), (B, 1024, 14, 14), (B, 2048, 7, 7)].
        """
        raise NotImplementedError
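

# A minimal illustrative sketch (not part of the template itself; feel free to delete it):
# the tiny 3-stage ConvNet below only demonstrates how the three required methods fit together.
# The stage widths (64, 128, 256) and strides (4, 2, 2) are arbitrary assumptions for this sketch,
# giving a total downsample ratio of 4 * 2 * 2 = 16 and feature map channels [64, 128, 256].
class _ToyConvNetSketch(YourConvNet):
    def __init__(self, num_classes=1000):
        super().__init__()
        widths, strides = [64, 128, 256], [4, 2, 2]
        self.stages = nn.ModuleList()
        in_ch = 3
        for w, s in zip(widths, strides):
            # each stage downsamples by its stride `s` and widens the channels to `w`
            self.stages.append(nn.Sequential(
                nn.Conv2d(in_ch, w, kernel_size=s, stride=s),
                nn.BatchNorm2d(w),
                nn.ReLU(inplace=True),
            ))
            in_ch = w
        self.head = nn.Linear(widths[-1], num_classes)
        self._widths, self._strides = widths, strides
    
    def get_downsample_ratio(self) -> int:
        ratio = 1
        for s in self._strides:
            ratio *= s
        return ratio    # 16 for strides (4, 2, 2)
    
    def get_feature_map_channels(self) -> List[int]:
        return list(self._widths)   # [64, 128, 256]
    
    def forward(self, inp_bchw: torch.Tensor, hierarchical=False):
        feats, x = [], inp_bchw
        for stage in self.stages:
            x = stage(x)
            feats.append(x)     # collect one feature map per stage
        if hierarchical:
            return feats        # same length as get_feature_map_channels()
        return self.head(x.mean(dim=(2, 3)))    # global average pooling -> logits of shape (B, num_classes)

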
@register_model
def your_convnet_small(pretrained=False, **kwargs):
    raise NotImplementedError   # delete this line once `YourConvNet` is implemented
    return YourConvNet(**kwargs)


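# Additional sizes can be registered the same way; the decorated function's name is the string
# that `timm.models.create_model(...)` looks up (as `convnet_test` below does). A hypothetical example:
#
# @register_model
# def your_convnet_base(pretrained=False, **kwargs):
#     return YourConvNet(**kwargs)    # e.g., pass larger widths/depths via kwargs

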
@torch.no_grad()
def convnet_test():
    from timm.models import create_model
    cnn = create_model('your_convnet_small')
    print('get_downsample_ratio:', cnn.get_downsample_ratio())
    print('get_feature_map_channels:', cnn.get_feature_map_channels())
    
    downsample_ratio = cnn.get_downsample_ratio()
    feature_map_channels = cnn.get_feature_map_channels()
    
    # check the forward function
    B, C, H, W = 4, 3, 224, 224
    inp = torch.rand(B, C, H, W)
    feats = cnn(inp, hierarchical=True)
    assert isinstance(feats, list)
    assert len(feats) == len(feature_map_channels)
    print([tuple(t.shape) for t in feats])
    
    # check the downsample ratio (spatial size of the last feature map)
    assert feats[-1].shape[-2] == H // downsample_ratio
    assert feats[-1].shape[-1] == W // downsample_ratio
    
    # check the channel number of each feature map
    for feat, ch in zip(feats, feature_map_channels):
        assert feat.ndim == 4
        assert feat.shape[1] == ch
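    
    # an optional sanity check, assuming the forward contract documented in `YourConvNet.forward`:
    # the non-hierarchical forward should return classification logits of shape (batch_size, num_classes)
    logits = cnn(inp, hierarchical=False)
    assert logits.ndim == 2 and logits.shape[0] == B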


if __name__ == '__main__':
    convnet_test()