[upd] 1. heavily refactor to simplify the pretraining code; 2. add a tutorial for customizing your own CNN model; 3. update some READMEs

main
keyu-tian 2 years ago
parent 46f9ad2871
commit 6ffe453fa5
Changed files:
  1. README.md (14)
  2. downstream_mmdet/README.md (4)
  3. downstream_mmdet/configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py (8)
  4. pretrain/README.md (31)
  5. pretrain/decoder.py (4)
  6. pretrain/dist.py (12)
  7. pretrain/encoder.py (10)
  8. pretrain/main.py (4)
  9. pretrain/models/__init__.py (44)
  10. pretrain/models/convnext.py (29)
  11. pretrain/models/custom.py (89)
  12. pretrain/models/resnet.py (76)
  13. pretrain/sampler.py (2)
  14. pretrain/spark.py (24)
  15. pretrain/utils/arg_util.py (23)

@ -25,7 +25,7 @@ https://user-images.githubusercontent.com/39692511/226858919-dd4ccf7e-a5ba-4a33-
## 🔥 News
- On **Mar. 22nd (UTC+8 8pm; UTC+0 12am)** another livestream would be held at [极市平台-bilibili](https://live.bilibili.com/3344545).
- On **Mar. 22nd (UTC+8 8pm)**, another livestream was held at 极市平台-Bilibili! [[`📹Recorded Video`](https://www.bilibili.com/video/BV1Da4y1T7mr/)]
- A talk was also given on [TechBeat (将门创投)](https://www.techbeat.net/talk-info?id=758) on **Mar. 16th (UTC+8 8pm)**! [[`📹Recorded Video`](https://www.techbeat.net/talk-info?id=758)]
- We are honored to be invited by Synced ("机器之心机动组 视频号" on WeChat) to give a talk about SparK on **Feb. 27th (UTC+0 11am, UTC+8 7pm)**, welcome! [[`📹Recorded Video`](https://www.bilibili.com/video/BV1J54y1u7U3/)]
- This work was accepted to ICLR 2023 as a Spotlight (notable-top-25%).
@ -44,7 +44,7 @@ https://user-images.githubusercontent.com/39692511/226858919-dd4ccf7e-a5ba-4a33-
<!-- ## 📺 Video demo (we use [these ppt slides](https://github.com/keyu-tian/SparK/releases/tag/file_sharing) to make the animated video) -->
<!-- https://user-images.githubusercontent.com/6366788/213662770-5f814de0-cbe8-48d9-8235-e8907fd81e0e.mp4 -->
## 🕹 CoLab Visualization Demo
## 🕹 Colab Visualization Demo
Check [pretrain/viz_reconstruction.ipynb](pretrain/viz_reconstruction.ipynb) for visualizing the reconstruction of SparK pretrained models, like:
@ -94,9 +94,10 @@ We also provide [pretrain/viz_spconv.ipynb](pretrain/viz_spconv.ipynb) that show
<summary>catalog</summary>
- [x] Pretraining code
- [x] Pretraining tutorial for custom CNN models ([pretrain/models/custom.py](pretrain/models/custom.py))
- [x] Pretraining Colab visualization playground ([reconstruction](/pretrain/viz_reconstruction.ipynb), [sparse conv](pretrain/viz_spconv.ipynb))
- [x] Finetuning code
- [x] Colab visualization playground ([reconstruction](/pretrain/viz_reconstruction.ipynb), [sparse conv](pretrain/viz_spconv.ipynb))
- [ ] Weights & visualization playground on `Huggingface`
- [ ] Weights & visualization playground in `huggingface`
- [ ] Weights in `timm`
</details>
@ -128,7 +129,10 @@ We also provide [pretrain/viz_spconv.ipynb](pretrain/viz_spconv.ipynb) that show
We highly recommend using `torch==1.10.0`, `torchvision==0.11.1`, and `timm==0.5.4` for reproduction.
Check [INSTALL.md](INSTALL.md) to install all pip dependencies.
- **Pretraining** all models on ImageNet-1k: &nbsp;see [pretrain/](pretrain)
- **Pretraining**
- all ResNets and ConvNeXts on ImageNet-1k: &nbsp;see [pretrain/](pretrain)
- **your own CNN models**: &nbsp;see [pretrain/](pretrain), especially [pretrain/models/custom.py](pretrain/models/custom.py)
- **Finetuning**
- all models on ImageNet: &nbsp;check [downstream_imagenet/](downstream_imagenet) for subsequent instructions.

@ -19,9 +19,9 @@ This `downstream_mmdet` is isolated from pre-training codes. One can treat this
<p>
## Installation [MMDetection with commit 6a979e2](https://github.com/facebookresearch/detectron2/releases/tag/v0.6) before fine-tuning ConvNeXt on COCO
## Installation [MMDetection with commit 6a979e2](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d) before fine-tuning ConvNeXt on COCO
We refer to the codebases of [ConvNeXt](https://github.com/facebookresearch/ConvNeXt/tree/048efcea897d999aed302f2639b6270aedf8d4c8) and [Swin-Transformer](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d).
We refer to the codebases of [ConvNeXt](https://github.com/facebookresearch/ConvNeXt/tree/048efcea897d999aed302f2639b6270aedf8d4c8) and [Swin-Transformer-Object-Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d).
Please refer to [README.md](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/6a979e2164e3fb0de0ca2546545013a4d71b2f7d/README.md) for installation and dataset preparation instructions.
Note the COCO dataset folder should be at `downstream_mmdet/data/coco`.

@ -23,13 +23,13 @@ _base_ = [
model = dict(
backbone=dict(
in_chans=3,
depths=[3, 3, 27, 3], # [modified] according to tiny to base
dims=[128, 256, 512, 1024], # [modified] according to tiny to base
drop_path_rate=0.5, # [modified] according to tiny to base
depths=[3, 3, 27, 3], # [modified] according to tiny-to-base
dims=[128, 256, 512, 1024], # [modified] according to tiny-to-base
drop_path_rate=0.5, # [modified] according to tiny-to-base
layer_scale_init_value=1.0,
out_indices=[0, 1, 2, 3],
),
neck=dict(in_channels=[128, 256, 512, 1024])) # [modified] according to tiny to base
neck=dict(in_channels=[128, 256, 512, 1024])) # [modified] according to tiny-to-base
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

@ -1,17 +1,30 @@
## Preparation for ImageNet-1k pre-training
See [INSTALL.md](https://github.com/keyu-tian/SparK/blob/main/INSTALL.md) to prepare `pip` dependencies and the ImageNet dataset.
See [/INSTALL.md](/INSTALL.md) to prepare `pip` dependencies and the ImageNet dataset.
**Note: for network definitions, we directly use `timm.models.ResNet` and [official ConvNeXt](https://github.com/facebookresearch/ConvNeXt/blob/048efcea897d999aed302f2639b6270aedf8d4c8/models/convnext.py).**
## Tutorial for customizing your own CNN model
See [/pretrain/models/custom.py](/pretrain/models/custom.py). What you need to do is:
- implement the member function `get_downsample_ratio` in [/pretrain/models/custom.py](/pretrain/models/custom.py).
- implement the member function `get_feature_map_channels` in [/pretrain/models/custom.py](/pretrain/models/custom.py).
- implement the member function `forward` in [/pretrain/models/custom.py](/pretrain/models/custom.py).
- define `your_convnet(...)` with `@register_model` in [/pretrain/models/custom.py](/pretrain/models/custom.py).
- add the default kwargs of `your_convnet(...)` in [/pretrain/models/__init__.py](/pretrain/models/__init__.py).
Then you can use `--model=your_convnet` in the pre-training script. A minimal sketch of such a model is given below.
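For concreteness, here is a minimal sketch of such a model (purely illustrative: the class name, layers, and widths are made up here and are not part of the repo; the real template lives in [/pretrain/models/custom.py](/pretrain/models/custom.py)):
```python
import torch
import torch.nn as nn
from typing import List
from timm.models.registry import register_model


class TinyConvNet(nn.Module):
    """A toy 4-stage CNN: stride-4 stem, then three stride-2 stages => total downsample ratio 32."""
    def __init__(self, dims=(96, 192, 384, 768), num_classes=1000, **unused_timm_kwargs):  # swallow timm kwargs like global_pool
        super().__init__()
        self.dims = list(dims)
        self.stem = nn.Sequential(nn.Conv2d(3, dims[0], kernel_size=4, stride=4), nn.BatchNorm2d(dims[0]), nn.ReLU(inplace=True))
        self.stages = nn.ModuleList()
        for i in range(4):
            cin, cout = dims[max(i - 1, 0)], dims[i]
            stride = 1 if i == 0 else 2                     # stage 0 stays at /4; later stages halve the resolution
            self.stages.append(nn.Sequential(
                nn.Conv2d(cin, cout, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(cout), nn.ReLU(inplace=True),
            ))
        self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity()

    def get_downsample_ratio(self) -> int:
        return 32                                           # 4 (stem) * 2 * 2 * 2

    def get_feature_map_channels(self) -> List[int]:
        return self.dims                                    # channels of the 4 feature maps, /4 to /32

    def forward(self, x: torch.Tensor, hierarchical=False):
        x = self.stem(x)
        feats = []
        for stage in self.stages:
            x = stage(x)
            feats.append(x)
        if hierarchical:
            return feats                                    # list of 4 feature maps, used by SparK
        return self.head(x.mean(dim=(-2, -1)))              # global average pool -> logits


@register_model
def your_convnet(pretrained=False, **kwargs):
    return TinyConvNet(**kwargs)
```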
## Pre-training Any Model on ImageNet-1k (224x224)
For pre-training, run [main.sh](https://github.com/keyu-tian/SparK/blob/main/pretrain/main.sh) with bash.
For pre-training, run [/pretrain/main.sh](/pretrain/main.sh) with bash.
It is **required** to specify the ImageNet data folder (`--data_path`), the model name (`--model`), and your experiment name (the first argument of `main.sh`) when running the script.
We use the **same** pre-training configurations (lr, batch size, etc.) for all models (ResNets and ConvNeXts).
Their names and **default values** can be found in [utils/arg_util.py line24-47](https://github.com/keyu-tian/SparK/blob/main/pretrain/utils/arg_util.py#L24).
Their names and **default values** can be found in [/pretrain/utils/arg_util.py line24-47](/pretrain/utils/arg_util.py).
These default configurations (like batch size 4096) will be used unless you override some of them, e.g., `--bs=512`.
Here is an example command for pre-training a ResNet50 on a single machine with 8 GPUs:
@ -69,8 +82,8 @@ Add `--resume_from=path/to/<model>still_pretraining.pth` to resume from a saved
## Regarding sparse convolution
We do not use sparse convolutions in this PyTorch implementation due to their limited optimization on modern hardware.
As can be found in [encoder.py](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py), we use masked dense convolution to simulate submanifold sparse convolution.
We also define some sparse pooling or normalization layers in [encoder.py](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py).
As can be found in [/pretrain/encoder.py](/pretrain/encoder.py), we use masked dense convolution to simulate submanifold sparse convolution.
We also define some sparse pooling or normalization layers in [/pretrain/encoder.py](/pretrain/encoder.py).
All these "sparse" layers are implemented through pytorch built-in operators.
@ -80,10 +93,8 @@ In SparK, the mask patch size **equals to** the downsample ratio of the CNN mode
Here is the reason: when masking, we:
1. first generate the binary mask for the **smallest** resolution feature map, i.e., generate the `_cur_active` or `active_b1ff` in [line86-87](https://github.com/keyu-tian/SparK/blob/main/pretrain/spark.py#L86), which is a `torch.BoolTensor` shaped as `[B, 1, fmap_size, fmap_size]`, and would be used to mask the smallest feature map.
3. then progressively upsample it (i.e., expand its 2nd and 3rd dimensions by calling `repeat_interleave(..., 2)` and `repeat_interleave(..., 3)` in [line16](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py#L16)), to mask those feature maps ([`x` in line21](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py#L21)) with larger resolutions .
1. first generate the binary mask for the **smallest**-resolution feature map, i.e., the `_cur_active` or `active_b1ff` in [/pretrain/spark.py line86-87](/pretrain/spark.py); this is a `torch.BoolTensor` shaped as `[B, 1, fmap_size, fmap_size]` and will be used to mask the smallest feature map.
2. then progressively upsample it (i.e., expand its 2nd and 3rd dimensions by calling `repeat_interleave(..., 2)` and `repeat_interleave(..., 3)` in [/pretrain/encoder.py line16](/pretrain/encoder.py)) to mask the feature maps ([`x` in line21](/pretrain/encoder.py)) with larger resolutions.
So if you want a patch size of 16 or 8, you should actually define a new CNN model with a downsample ratio of 16 or 8.
Note that the `forward` function of this CNN should have an arg named `hierarchy`. You can look at https://github.com/keyu-tian/SparK/blob/main/pretrain/models/convnext.py#L78 to see what `hierarchy` means and how to handle it.
After that, you can simply run `main.sh` with `--hierarchy=3` and see if it works.
See `Tutorial for customizing your own CNN model` above.
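To make the two steps above concrete, here is a minimal standalone sketch of the masking logic (the sizes and tensors are illustrative values, not the repo's exact code):
```python
import torch

B, fmap_size, mask_ratio = 2, 7, 0.6          # 7 = 224 / 32 for a downsample ratio of 32
# step 1: boolean mask on the smallest feature map; True = kept (active), False = masked
active_b1ff = torch.rand(B, 1, fmap_size, fmap_size) > mask_ratio     # [B, 1, 7, 7]

# step 2: upsample the mask to a higher-resolution feature map and apply it
x = torch.rand(B, 512, 28, 28)                                        # e.g. the /8 feature map
scale = x.shape[-1] // active_b1ff.shape[-1]                          # 28 // 7 = 4
mask = active_b1ff.repeat_interleave(scale, dim=2).repeat_interleave(scale, dim=3)
x = x * mask                                                          # zero out masked positions
```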

@ -32,12 +32,12 @@ class UNetBlock(nn.Module):
class LightDecoder(nn.Module):
def __init__(self, up_sample_ratio, width=768, sbn=True):
def __init__(self, up_sample_ratio, width=768, sbn=True): # todo: the decoder's width follows a simple halving rule; you can change it to any other rule
super().__init__()
self.width = width
assert is_pow2n(up_sample_ratio)
n = round(math.log2(up_sample_ratio))
channels = [self.width // 2 ** i for i in range(n + 1)]
channels = [self.width // 2 ** i for i in range(n + 1)] # todo: the decoder's width follows a simple halving rule; you can change it to any other rule
bn2d = nn.SyncBatchNorm if sbn else nn.BatchNorm2d
self.dec = nn.ModuleList([UNetBlock(cin, cout, bn2d) for (cin, cout) in zip(channels[:-1], channels[1:])])
self.proj = nn.Conv2d(channels[-1], 3, kernel_size=1, stride=1, bias=True)
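For reference, the halving rule above yields the following channel schedule for the default `width=768` and a downsample ratio of 32 (a worked example, not new repo code):
```python
import math

width, up_sample_ratio = 768, 32
n = round(math.log2(up_sample_ratio))                  # 5 upsampling steps
channels = [width // 2 ** i for i in range(n + 1)]     # [768, 384, 192, 96, 48, 24]
# each UNetBlock maps channels[i] -> channels[i+1]; the final 1x1 conv projects 24 -> 3 (RGB)
```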

@ -8,12 +8,11 @@ import os
from typing import List
from typing import Union
import sys
import torch
import torch.distributed as tdist
import torch.multiprocessing as mp
from torch.distributed import barrier as __barrier
barrier = __barrier
__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu'
__initialized = False
@ -23,6 +22,10 @@ def initialized():
def initialize(backend='nccl'):
if not torch.cuda.is_available():
print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
return
# ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn')
@ -64,6 +67,11 @@ def is_local_master():
return __local_rank == 0
def barrier():
if __initialized:
tdist.barrier()
def parallelize(net, syncbn=False):
if syncbn:
net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)

@ -156,10 +156,10 @@ class SparseConvNeXtBlock(nn.Module):
class SparseEncoder(nn.Module):
def __init__(self, conv_model, input_size, downsample_raito, encoder_fea_dim, sbn=False, verbose=False):
def __init__(self, cnn, input_size, sbn=False, verbose=False):
super(SparseEncoder, self).__init__()
self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=conv_model, verbose=verbose, sbn=sbn)
self.input_size, self.downsample_raito, self.fea_dim = input_size, downsample_raito, encoder_fea_dim
self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=cnn, verbose=verbose, sbn=sbn)
self.input_size, self.downsample_raito, self.enc_feat_map_chs = input_size, cnn.get_downsample_ratio(), cnn.get_feature_map_channels()
@staticmethod
def dense_model_to_sparse(m: nn.Module, verbose=False, sbn=False):
@ -204,5 +204,5 @@ class SparseEncoder(nn.Module):
del m
return oup
def forward(self, x, hierarchy):
return self.sp_cnn(x, hierarchy=hierarchy)
def forward(self, x):
return self.sp_cnn(x, hierarchical=True)

@ -37,7 +37,7 @@ def main_pt():
data_loader_train = DataLoader(
dataset=dataset_train, num_workers=args.dataloader_workers, pin_memory=True,
batch_sampler=DistInfiniteBatchSampler(
dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, seed=args.seed,
dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size,
shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(),
), worker_init_fn=worker_init_fn
)
@ -49,7 +49,7 @@ def main_pt():
dec = LightDecoder(enc.downsample_raito, sbn=args.sbn)
model_without_ddp = SparK(
sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask,
densify_norm=args.densify_norm, sbn=args.sbn, hierarchy=args.hierarchy,
densify_norm=args.densify_norm, sbn=args.sbn,
).to(args.device)
print(f'[PT model] model = {model_without_ddp}\n')
model: DistributedDataParallel = DistributedDataParallel(model_without_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)

@ -30,38 +30,30 @@ for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath):
clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
model_alias_to_fullname = {
'res50': 'resnet50',
'res101': 'resnet101',
'res152': 'resnet152',
'res200': 'resnet200',
'cnxS': 'convnext_small',
'cnxB': 'convnext_base',
'cnxL': 'convnext_large',
pretrain_default_model_kwargs = {
'your_convnet': dict(),
'resnet50': dict(drop_path_rate=0.05),
'resnet101': dict(drop_path_rate=0.08),
'resnet152': dict(drop_path_rate=0.10),
'resnet200': dict(drop_path_rate=0.15),
'convnext_small': dict(sparse=True, drop_path_rate=0.2),
'convnext_base': dict(sparse=True, drop_path_rate=0.3),
'convnext_large': dict(sparse=True, drop_path_rate=0.4),
}
model_fullname_to_alias = {v: k for k, v in model_alias_to_fullname.items()}
pre_train_d = { # default drop_path_rate, num of para, FLOPs, downsample_ratio, num of channel
'resnet50': [dict(drop_path_rate=0.05), 25.6, 4.1, 32, 2048],
'resnet101': [dict(drop_path_rate=0.08), 44.5, 7.9, 32, 2048],
'resnet152': [dict(drop_path_rate=0.10), 60.2, 11.6, 32, 2048],
'resnet200': [dict(drop_path_rate=0.15), 64.7, 15.1, 32, 2048],
'convnext_small': [dict(sparse=True, drop_path_rate=0.2), 50.0, 8.7, 32, 768],
'convnext_base': [dict(sparse=True, drop_path_rate=0.3), 89.0, 15.4, 32, 1024],
'convnext_large': [dict(sparse=True, drop_path_rate=0.4), 198.0, 34.4, 32, 1536],
}
for v in pre_train_d.values():
v[0]['pretrained'] = False
v[0]['num_classes'] = 0
v[0]['global_pool'] = ''
for kw in pretrain_default_model_kwargs.values():
kw['pretrained'] = False
kw['num_classes'] = 0
kw['global_pool'] = ''
def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False):
from encoder import SparseEncoder
kwargs, params, flops, downsample_raito, fea_dim = pre_train_d[name]
kwargs = pretrain_default_model_kwargs[name]
if drop_path_rate != 0:
kwargs['drop_path_rate'] = drop_path_rate
print(f'[build_sparse_encoder] model kwargs={kwargs}')
return SparseEncoder(create_model(name, **kwargs), input_size=input_size, downsample_raito=downsample_raito, encoder_fea_dim=fea_dim, sbn=sbn, verbose=verbose)
cnn = create_model(name, **kwargs)
return SparseEncoder(cnn, input_size=input_size, sbn=sbn, verbose=verbose)
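With this change, the encoder exposes its downsample ratio and feature-map channels by itself. A hypothetical usage sketch (import paths and argument values are illustrative):
```python
from models import build_sparse_encoder
from decoder import LightDecoder

enc = build_sparse_encoder('resnet50', input_size=224, sbn=False, verbose=False)
print(enc.downsample_raito, enc.enc_feat_map_chs)    # 32, [256, 512, 1024, 2048] for a ResNet-50
dec = LightDecoder(enc.downsample_raito, sbn=False)  # decoder width derived from the encoder
```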

@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
#
# This file is basically a copy of: https://github.com/facebookresearch/ConvNeXt/blob/06f7b05f922e21914916406141f50f82b4a15852/models/convnext.py
from typing import List
import torch
import torch.nn as nn
@ -34,7 +35,7 @@ class ConvNeXt(nn.Module):
sparse=True,
):
super().__init__()
self.dims: List[int] = dims
self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
@ -75,13 +76,19 @@ class ConvNeXt(nn.Module):
trunc_normal_(m.weight, std=.02)
nn.init.constant_(m.bias, 0)
def forward(self, x, hierarchy=0):
def get_downsample_ratio(self) -> int:
return 32
def get_feature_map_channels(self) -> List[int]:
return self.dims
def forward(self, x, hierarchical=False):
if hierarchical:
ls = []
for i in range(4):
x = self.downsample_layers[i](x)
x = self.stages[i](x)
ls.append(x if hierarchy >= 4-i else None)
if hierarchy:
ls.append(x)
return ls
else:
return self.fc(self.norm(x.mean([-2, -1]))) # (B, C, H, W) =mean=> (B, C) =norm&fc=> (B, NumCls)
@ -116,17 +123,3 @@ def convnext_large(pretrained=False, in_22k=False, **kwargs):
model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
return model
if __name__ == '__main__':
from timm.models import create_model
cnx = create_model('convnext_small', sparse=False)
def prt(lst):
print([tuple(t.shape) if t is not None else '(None)' for t in lst])
with torch.no_grad():
inp = torch.rand(2, 3, 224, 224)
prt(cnx(inp))
prt(cnx(inp, hierarchy=1))
prt(cnx(inp, hierarchy=2))
prt(cnx(inp, hierarchy=3))
prt(cnx(inp, hierarchy=4))
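For reference, a minimal self-test against the new `hierarchical` interface could look like this (a sketch, not part of the commit):
```python
if __name__ == '__main__':
    from timm.models import create_model
    cnx = create_model('convnext_small', sparse=False)
    with torch.no_grad():
        inp = torch.rand(2, 3, 224, 224)
        feats = cnx(inp, hierarchical=True)   # a list of 4 feature maps at strides 4, 8, 16, 32
        print([tuple(t.shape) for t in feats])
```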

@ -0,0 +1,89 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from typing import List
from timm.models.registry import register_model
class YourConvNet(nn.Module):
"""
This is a template for your custom ConvNet.
It is required to implement the following three functions: `get_downsample_ratio`, `get_feature_map_channels`, `forward`.
You can refer to the implementation in `pretrain/models/resnet.py` for an example.
"""
def get_downsample_ratio(self) -> int:
"""
This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).
:return: the TOTAL downsample ratio of the ConvNet.
E.g., for a ResNet-50, this should return 32.
"""
raise NotImplementedError
def get_feature_map_channels(self) -> List[int]:
"""
This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).
:return: a list of the number of channels of each feature map.
E.g., for a ResNet-50, this should return [256, 512, 1024, 2048].
"""
raise NotImplementedError
def forward(self, inp_bchw: torch.Tensor, hierarchical=False):
"""
The forward with `hierarchical=True` would ONLY be used in `SparseEncoder.forward` (see `pretrain/encoder.py`).
:param inp_bchw: input image tensor, shape: (batch_size, channels, height, width).
:param hierarchical: return the logits (not hierarchical), or the feature maps (hierarchical).
:return:
- hierarchical == False: return the logits of the classification task, shape: (batch_size, num_classes).
- hierarchical == True: return a list of all feature maps, which should have the same length as the return value of `get_feature_map_channels`.
E.g., for a ResNet-50, it should return a list [1st_feat_map, 2nd_feat_map, 3rd_feat_map, 4th_feat_map].
for an input size of 224, the shapes are [(B, 256, 56, 56), (B, 512, 28, 28), (B, 1024, 14, 14), (B, 2048, 7, 7)]
"""
raise NotImplementedError
@register_model
def your_convnet_small(pretrained=False, **kwargs):
raise NotImplementedError
return YourConvNet(**kwargs)
@torch.no_grad()
def convnet_test():
from timm.models import create_model
cnn = create_model('your_convnet_small')
print('get_downsample_ratio:', cnn.get_downsample_ratio())
print('get_feature_map_channels:', cnn.get_feature_map_channels())
downsample_ratio = cnn.get_downsample_ratio()
feature_map_channels = cnn.get_feature_map_channels()
# check the forward function
B, C, H, W = 4, 3, 224, 224
inp = torch.rand(B, C, H, W)
feats = cnn(inp, hierarchical=True)
assert isinstance(feats, list)
assert len(feats) == len(feature_map_channels)
print([tuple(t.shape) for t in feats])
# check the downsample ratio
feats = cnn(inp, hierarchical=True)
assert feats[-1].shape[-2] == H // downsample_ratio
assert feats[-1].shape[-1] == W // downsample_ratio
# check the channel number
for feat, ch in zip(feats, feature_map_channels):
assert feat.ndim == 4
assert feat.shape[1] == ch
if __name__ == '__main__':
convnet_test()

@ -3,15 +3,26 @@
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import List
import torch
import torch.nn.functional as F
from timm.models.resnet import ResNet
def forward(self, x, hierarchy=0): # hierarchy: 0 or 1 or 2 or 3 or 4
# hack: inject the `get_downsample_ratio` function into `timm.models.resnet.ResNet`
def get_downsample_ratio(self: ResNet) -> int:
return 32
# hack: inject the `get_feature_map_channels` function into `timm.models.resnet.ResNet`
def get_feature_map_channels(self: ResNet) -> List[int]:
# `self.feature_info` is maintained by `timm`
return [info['num_chs'] for info in self.feature_info[1:]]
# hack: override the forward function of `timm.models.resnet.ResNet`
def forward(self, x, hierarchical=False):
""" this forward function is a modified version of `timm.models.resnet.ResNet.forward`
>>> ResNet.forward
"""
@ -20,17 +31,12 @@ def forward(self, x, hierarchy=0): # hierarchy: 0 or 1 or 2 or 3 or 4
x = self.act1(x)
x = self.maxpool(x)
if hierarchical:
ls = []
x = self.layer1(x)
ls.append(x if hierarchy >= 4 else None)
x = self.layer2(x)
ls.append(x if hierarchy >= 3 else None)
x = self.layer3(x)
ls.append(x if hierarchy >= 2 else None)
x = self.layer4(x)
ls.append(x if hierarchy >= 1 else None)
if hierarchy:
x = self.layer1(x); ls.append(x)
x = self.layer2(x); ls.append(x)
x = self.layer3(x); ls.append(x)
x = self.layer4(x); ls.append(x)
return ls
else:
x = self.global_pool(x)
@ -40,19 +46,39 @@ def forward(self, x, hierarchy=0): # hierarchy: 0 or 1 or 2 or 3 or 4
return x
ResNet.get_downsample_ratio = get_downsample_ratio
ResNet.get_feature_map_channels = get_feature_map_channels
ResNet.forward = forward
if __name__ == '__main__':
@torch.no_grad()
def convnet_test():
from timm.models import create_model
r50 = create_model('resnet50')
def prt(lst):
print([tuple(t.shape) if t is not None else '(None)' for t in lst])
with torch.no_grad():
inp = torch.rand(2, 3, 224, 224)
prt(r50(inp))
prt(r50(inp, hierarchy=1))
prt(r50(inp, hierarchy=2))
prt(r50(inp, hierarchy=3))
prt(r50(inp, hierarchy=4))
cnn = create_model('resnet50')
print('get_downsample_ratio:', cnn.get_downsample_ratio())
print('get_feature_map_channels:', cnn.get_feature_map_channels())
downsample_ratio = cnn.get_downsample_ratio()
feature_map_channels = cnn.get_feature_map_channels()
# check the forward function
B, C, H, W = 4, 3, 224, 224
inp = torch.rand(B, C, H, W)
feats = cnn(inp, hierarchical=True)
assert isinstance(feats, list)
assert len(feats) == len(feature_map_channels)
print([tuple(t.shape) for t in feats])
# check the downsample ratio
feats = cnn(inp, hierarchical=True)
assert feats[-1].shape[-2] == H // downsample_ratio
assert feats[-1].shape[-1] == W // downsample_ratio
# check the channel number
for feat, ch in zip(feats, feature_map_channels):
assert feat.ndim == 4
assert feat.shape[1] == ch
if __name__ == '__main__':
convnet_test()

@ -19,7 +19,7 @@ def worker_init_fn(worker_id):
class DistInfiniteBatchSampler(Sampler):
def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=0, filling=False, shuffle=True):
def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=1, filling=False, shuffle=True):
assert glb_batch_size % world_size == 0
self.world_size, self.rank = world_size, rank
self.dataset_len = dataset_len

@ -20,7 +20,7 @@ from decoder import LightDecoder
class SparK(nn.Module):
def __init__(
self, sparse_encoder: encoder.SparseEncoder, dense_decoder: LightDecoder,
mask_ratio=0.6, densify_norm='bn', sbn=False, hierarchy=4,
mask_ratio=0.6, densify_norm='bn', sbn=False,
):
super().__init__()
input_size, downsample_raito = sparse_encoder.input_size, sparse_encoder.downsample_raito
@ -33,15 +33,23 @@ class SparK(nn.Module):
self.dense_decoder = dense_decoder
self.sbn = sbn
self.hierarchy = hierarchy
self.hierarchy = len(sparse_encoder.enc_feat_map_chs)
self.densify_norm_str = densify_norm.lower()
self.densify_norms = nn.ModuleList()
self.densify_projs = nn.ModuleList()
self.mask_tokens = nn.ParameterList()
# build the `densify` layers
e_width, d_width = self.sparse_encoder.fea_dim, self.dense_decoder.width
for i in range(self.hierarchy):
e_widths, d_width = self.sparse_encoder.enc_feat_map_chs, self.dense_decoder.width
e_widths: List[int]
for i in range(self.hierarchy): # from the smallest feat map to the largest; i=0: the last feat map; i=1: the second last feat map ...
e_width = e_widths.pop()
# create mask token
p = nn.Parameter(torch.zeros(1, e_width, 1, 1))
trunc_normal_(p, mean=0, std=.02, a=-.02, b=.02)
self.mask_tokens.append(p)
# create densify norm
if self.densify_norm_str == 'bn':
densify_norm = (encoder.SparseSyncBatchNorm2d if self.sbn else encoder.SparseBatchNorm2d)(e_width)
elif self.densify_norm_str == 'ln':
@ -50,6 +58,7 @@ class SparK(nn.Module):
densify_norm = nn.Identity()
self.densify_norms.append(densify_norm)
# create densify proj
if i == 0 and e_width == d_width:
densify_proj = nn.Identity() # todo: NOTE THAT CONVNEXT-S WOULD USE THIS, because it has a width of 768 that equals the decoder's width of 768
print(f'[SparK.__init__, densify {i+1}/{self.hierarchy}]: use nn.Identity() as densify_proj')
@ -59,10 +68,7 @@ class SparK(nn.Module):
print(f'[SparK.__init__, densify {i+1}/{self.hierarchy}]: densify_proj(ksz={kernel_size}, #para={sum(x.numel() for x in densify_proj.parameters()) / 1e6:.2f}M)')
self.densify_projs.append(densify_proj)
p = nn.Parameter(torch.zeros(1, e_width, 1, 1))
trunc_normal_(p, mean=0, std=.02, a=-.02, b=.02)
self.mask_tokens.append(p)
e_width //= 2
# todo: the decoder's width follows a simple halving rule; you can change it to any other rule
d_width //= 2
print(f'[SparK.__init__] dims of mask_tokens={tuple(p.numel() for p in self.mask_tokens)}')
@ -89,7 +95,7 @@ class SparK(nn.Module):
masked_bchw = inp_bchw * active_b1hw
# step2. Encode: get hierarchical encoded sparse features (a list containing 4 feature maps at 4 scales)
fea_bcffs: List[torch.Tensor] = self.sparse_encoder(masked_bchw, hierarchy=self.hierarchy)
fea_bcffs: List[torch.Tensor] = self.sparse_encoder(masked_bchw)
fea_bcffs.reverse() # after reversion: from the smallest feature map to the largest
# step3. Densify: get hierarchical dense features for decoding

@ -15,19 +15,16 @@ import dist
class Args(Tap):
# environment
exp_name: str
exp_dir: str
data_path: str
exp_name: str = 'your_exp_name'
exp_dir: str = 'your_exp_dir' # will be created if it does not exist
data_path: str = 'imagenet_data_path'
resume_from: str = '' # resume from some checkpoint.pth
seed: int = 1
# SparK hyperparameters
mask: float = 0.6
hierarchy: int = 4
mask: float = 0.6 # mask ratio, should be in (0, 1)
# encoder hyperparameters
model: str = 'res50'
model_alias: str = 'res50'
model: str = 'resnet50'
input_size: int = 224
sbn: bool = True
@ -70,7 +67,7 @@ class Args(Tap):
@property
def is_resnet(self):
return 'resnet' in self.model or 'res' in self.model_alias
return 'resnet' in self.model
def log_epoch(self):
if not dist.is_local_master():
@ -96,7 +93,6 @@ class Args(Tap):
def init_dist_and_get_args():
from utils import misc
from models import model_alias_to_fullname, model_fullname_to_alias
# initialize
args = Args(explicit_bool=True).parse_args()
@ -116,10 +112,6 @@ def init_dist_and_get_args():
misc.init_distributed_environ(exp_dir=args.exp_dir)
# update args
if args.model in model_alias_to_fullname.keys():
args.model = model_alias_to_fullname[args.model]
args.model_alias = model_fullname_to_alias[args.model]
args.first_logging = True
args.device = dist.get_device()
args.batch_size_per_gpu = args.bs // dist.get_world_size()
@ -137,7 +129,4 @@ def init_dist_and_get_args():
args.lr = args.base_lr * args.glb_batch_size / 256
args.wde = args.wde or args.wd
if args.hierarchy < 1:
args.hierarchy = 1
return args
