[upd] 1. refactor a lot to simplify the pretraining codes; 2. add tutorial for customizing your own CNN model; 3. update some READMEs
parent
46f9ad2871
commit
6ffe453fa5
15 changed files with 250 additions and 132 deletions
@ -0,0 +1,89 @@ |
|||||||
|
# Copyright (c) ByteDance, Inc. and its affiliates. |
||||||
|
# All rights reserved. |
||||||
|
# |
||||||
|
# This source code is licensed under the license found in the |
||||||
|
# LICENSE file in the root directory of this source tree. |
||||||
|
|
||||||
|
import torch |
||||||
|
import torch.nn as nn |
||||||
|
from typing import List |
||||||
|
from timm.models.registry import register_model |
||||||
|
|
||||||
|
|
||||||
|
class YourConvNet(nn.Module): |
||||||
|
""" |
||||||
|
This is a template for your custom ConvNet. |
||||||
|
It is required to implement the following three functions: `get_downsample_ratio`, `get_feature_map_channels`, `forward`. |
||||||
|
You can refer to the implementations in `pretrain\models\resnet.py` for an example. |
||||||
|
""" |
||||||
|
|
||||||
|
def get_downsample_ratio(self) -> int: |
||||||
|
""" |
||||||
|
This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`). |
||||||
|
|
||||||
|
:return: the TOTAL downsample ratio of the ConvNet. |
||||||
|
E.g., for a ResNet-50, this should return 32. |
||||||
|
""" |
||||||
|
raise NotImplementedError |
||||||
|
|
||||||
|
def get_feature_map_channels(self) -> List[int]: |
||||||
|
""" |
||||||
|
This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`). |
||||||
|
|
||||||
|
:return: a list of the number of channels of each feature map. |
||||||
|
E.g., for a ResNet-50, this should return [256, 512, 1024, 2048]. |
||||||
|
""" |
||||||
|
raise NotImplementedError |
||||||
|
|
||||||
|
def forward(self, inp_bchw: torch.Tensor, hierarchical=False): |
||||||
|
""" |
||||||
|
The forward with `hierarchical=True` would ONLY be used in `SparseEncoder.forward` (see `pretrain/encoder.py`). |
||||||
|
|
||||||
|
:param inp_bchw: input image tensor, shape: (batch_size, channels, height, width). |
||||||
|
:param hierarchical: return the logits (not hierarchical), or the feature maps (hierarchical). |
||||||
|
:return: |
||||||
|
- hierarchical == False: return the logits of the classification task, shape: (batch_size, num_classes). |
||||||
|
- hierarchical == True: return a list of all feature maps, which should have the same length as the return value of `get_feature_map_channels`. |
||||||
|
E.g., for a ResNet-50, it should return a list [1st_feat_map, 2nd_feat_map, 3rd_feat_map, 4th_feat_map]. |
||||||
|
for an input size of 224, the shapes are [(B, 256, 56, 56), (B, 512, 28, 28), (B, 1024, 14, 14), (B, 2048, 7, 7)] |
||||||
|
""" |
||||||
|
raise NotImplementedError |
||||||
|
|
||||||
|
|
||||||
|
@register_model |
||||||
|
def your_convnet_small(pretrained=False, **kwargs): |
||||||
|
raise NotImplementedError |
||||||
|
return YourConvNet(**kwargs) |
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad() |
||||||
|
def convnet_test(): |
||||||
|
from timm.models import create_model |
||||||
|
cnn = create_model('your_convnet_small') |
||||||
|
print('get_downsample_ratio:', cnn.get_downsample_ratio()) |
||||||
|
print('get_feature_map_channels:', cnn.get_feature_map_channels()) |
||||||
|
|
||||||
|
downsample_ratio = cnn.get_downsample_ratio() |
||||||
|
feature_map_channels = cnn.get_feature_map_channels() |
||||||
|
|
||||||
|
# check the forward function |
||||||
|
B, C, H, W = 4, 3, 224, 224 |
||||||
|
inp = torch.rand(B, C, H, W) |
||||||
|
feats = cnn(inp, hierarchical=True) |
||||||
|
assert isinstance(feats, list) |
||||||
|
assert len(feats) == len(feature_map_channels) |
||||||
|
print([tuple(t.shape) for t in feats]) |
||||||
|
|
||||||
|
# check the downsample ratio |
||||||
|
feats = cnn(inp, hierarchical=True) |
||||||
|
assert feats[-1].shape[-2] == H // downsample_ratio |
||||||
|
assert feats[-1].shape[-1] == W // downsample_ratio |
||||||
|
|
||||||
|
# check the channel number |
||||||
|
for feat, ch in zip(feats, feature_map_channels): |
||||||
|
assert feat.ndim == 4 |
||||||
|
assert feat.shape[1] == ch |
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__': |
||||||
|
convnet_test() |
Loading…
Reference in new issue