diff --git a/README.md b/README.md
index 22d4a58..d345199 100644
--- a/README.md
+++ b/README.md
@@ -25,7 +25,7 @@ https://user-images.githubusercontent.com/39692511/226858919-dd4ccf7e-a5ba-4a33-
## 🔥 News
-- On **Mar. 22nd (UTC+8 8pm; UTC+0 12am)** another livestream would be held at [极市平台-bilibili](https://live.bilibili.com/3344545).
+- On **Mar. 22nd (UTC+8 8pm)** another livestream will be held at 极市平台-Bilibili! [[`📹Recorded Video`](https://www.bilibili.com/video/BV1Da4y1T7mr/)]
- The talk on [TechBeat (将门创投)](https://www.techbeat.net/talk-info?id=758) is scheduled on **Mar. 16th (UTC+8 8pm)** too! [[`📹Recorded Video`](https://www.techbeat.net/talk-info?id=758)]
- We are honored to be invited by Synced ("机器之心机动组 视频号" on WeChat) to give a talk about SparK on **Feb. 27th (UTC+0 11am, UTC+8 7pm)**, welcome! [[`📹Recorded Video`](https://www.bilibili.com/video/BV1J54y1u7U3/)]
- This work got accepted to ICLR 2023 as a Spotlight (notable-top-25%).
@@ -44,7 +44,7 @@ https://user-images.githubusercontent.com/39692511/226858919-dd4ccf7e-a5ba-4a33-
-## 🕹️ CoLab Visualization Demo
+## 🕹️ Colab Visualization Demo
Check [pretrain/viz_reconstruction.ipynb](pretrain/viz_reconstruction.ipynb) for visualizing the reconstruction of SparK pretrained models, like:
@@ -94,9 +94,10 @@ We also provide [pretrain/viz_spconv.ipynb](pretrain/viz_spconv.ipynb) that show
catalog
- [x] Pretraining code
+- [x] Pretraining tutorial for custom CNN models ([pretrain/models/custom.py](pretrain/models/custom.py))
+- [x] Pretraining Colab visualization playground ([reconstruction](/pretrain/viz_reconstruction.ipynb), [sparse conv](pretrain/viz_spconv.ipynb))
- [x] Finetuning code
-- [x] Colab visualization playground ([reconstruction](/pretrain/viz_reconstruction.ipynb), [sparse conv](pretrain/viz_spconv.ipynb))
-- [ ] Weights & visualization playground on `Huggingface`
+- [ ] Weights & visualization playground on `huggingface`
- [ ] Weights in `timm`
@@ -128,7 +129,10 @@ We also provide [pretrain/viz_spconv.ipynb](pretrain/viz_spconv.ipynb) that show
We highly recommend that you use `torch==1.10.0`, `torchvision==0.11.1`, and `timm==0.5.4` for reproduction.
Check [INSTALL.md](INSTALL.md) to install all pip dependencies.
-- **Pretraining** all models on ImageNet-1k: see [pretrain/](pretrain)
+- **Pretraining**
+ - all ResNets and ConvNeXts on ImageNet-1k: see [pretrain/](pretrain)
+ - **your own CNN models**: see [pretrain/](pretrain), especially [pretrain/models/custom.py](pretrain/models/custom.py)
+
- **Finetuning**
- all models on ImageNet: check [downstream_imagenet/](downstream_imagenet) for subsequent instructions.
diff --git a/downstream_mmdet/README.md b/downstream_mmdet/README.md
index 77d0dbf..85d7bb2 100644
--- a/downstream_mmdet/README.md
+++ b/downstream_mmdet/README.md
@@ -19,9 +19,9 @@ This `downstream_mmdet` is isolated from pre-training codes. One can treat this
-## Installation [MMDetection with commit 6a979e2](https://github.com/facebookresearch/detectron2/releases/tag/v0.6) before fine-tuning ConvNeXt on COCO
+## Installation of [MMDetection with commit 6a979e2](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d) before fine-tuning ConvNeXt on COCO
-We refer to the codebases of [ConvNeXt](https://github.com/facebookresearch/ConvNeXt/tree/048efcea897d999aed302f2639b6270aedf8d4c8) and [Swin-Transformer](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d).
+We refer to the codebases of [ConvNeXt](https://github.com/facebookresearch/ConvNeXt/tree/048efcea897d999aed302f2639b6270aedf8d4c8) and [Swin-Transformer-Object-Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d).
Please refer to [README.md](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/6a979e2164e3fb0de0ca2546545013a4d71b2f7d/README.md) for installation and dataset preparation instructions.
Note the COCO dataset folder should be at `downstream_mmdet/data/coco`.
diff --git a/downstream_mmdet/configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py b/downstream_mmdet/configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py
index 465fe4f..566866b 100644
--- a/downstream_mmdet/configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py
+++ b/downstream_mmdet/configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py
@@ -23,13 +23,13 @@ _base_ = [
model = dict(
backbone=dict(
in_chans=3,
- depths=[3, 3, 27, 3], # [modified] according to tiny to base
- dims=[128, 256, 512, 1024], # [modified] according to tiny to base
- drop_path_rate=0.5, # [modified] according to tiny to base
+ depths=[3, 3, 27, 3], # [modified] according to tiny-to-base
+ dims=[128, 256, 512, 1024], # [modified] according to tiny-to-base
+ drop_path_rate=0.5, # [modified] according to tiny-to-base
layer_scale_init_value=1.0,
out_indices=[0, 1, 2, 3],
),
- neck=dict(in_channels=[128, 256, 512, 1024])) # [modified] according to tiny to base
+ neck=dict(in_channels=[128, 256, 512, 1024])) # [modified] according to tiny-to-base
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
diff --git a/pretrain/README.md b/pretrain/README.md
index 0b1217d..0fb8f94 100644
--- a/pretrain/README.md
+++ b/pretrain/README.md
@@ -1,17 +1,30 @@
## Preparation for ImageNet-1k pre-training
-See [INSTALL.md](https://github.com/keyu-tian/SparK/blob/main/INSTALL.md) to prepare `pip` dependencies and the ImageNet dataset.
+See [/INSTALL.md](/INSTALL.md) to prepare `pip` dependencies and the ImageNet dataset.
**Note: for network definitions, we directly use `timm.models.ResNet` and [official ConvNeXt](https://github.com/facebookresearch/ConvNeXt/blob/048efcea897d999aed302f2639b6270aedf8d4c8/models/convnext.py).**
+## Tutorial for customizing your own CNN model
+
+See [/pretrain/models/custom.py](/pretrain/models/custom.py). What you need to do is:
+
+- implement the member function `get_downsample_ratio` in [/pretrain/models/custom.py](/pretrain/models/custom.py).
+- implement the member function `get_feature_map_channels` in [/pretrain/models/custom.py](/pretrain/models/custom.py).
+- implement the member function `forward` in [/pretrain/models/custom.py](/pretrain/models/custom.py).
+- define `your_convnet(...)` with `@register_model` in [/pretrain/models/custom.py](/pretrain/models/custom.py).
+- add the default kwargs of `your_convnet(...)` in [/pretrain/models/__init__.py](/pretrain/models/__init__.py).
+
+Then you can use `--model=your_convnet` in the pre-training script.
+
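+For example, assuming your ImageNet folder is at `/path/to/imagenet`, the launch would look like `bash ./main.sh your_exp_name --data_path=/path/to/imagenet --model=your_convnet` (the first argument of `main.sh` is the experiment name; see the pre-training section below for details).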
+
## Pre-training Any Model on ImageNet-1k (224x224)
-For pre-training, run [main.sh](https://github.com/keyu-tian/SparK/blob/main/pretrain/main.sh) with bash.
+For pre-training, run [/pretrain/main.sh](/pretrain/main.sh) with bash.
It is **required** to specify the ImageNet data folder (`--data_path`), the model name (`--model`), and your experiment name (the first argument of `main.sh`) when running the script.
We use the **same** pre-training configurations (lr, batch size, etc.) for all models (ResNets and ConvNeXts).
-Their names and **default values** can be found in [utils/arg_util.py line24-47](https://github.com/keyu-tian/SparK/blob/main/pretrain/utils/arg_util.py#L24).
+Their names and **default values** can be found in [/pretrain/utils/arg_util.py lines 24-47](/pretrain/utils/arg_util.py).
These default configurations (like batch size 4096) will be used unless you override them, e.g. `--bs=512`.
Here is an example command for pre-training a ResNet50 on a single machine with 8 GPUs:
@@ -69,8 +82,8 @@ Add `--resume_from=path/to/still_pretraining.pth` to resume from a saved
## Regarding sparse convolution
We do not use sparse convolutions in this pytorch implementation, due to their limited optimization on modern hardware.
-As can be found in [encoder.py](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py), we use masked dense convolution to simulate submanifold sparse convolution.
-We also define some sparse pooling or normalization layers in [encoder.py](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py).
+As can be found in [/pretrain/encoder.py](/pretrain/encoder.py), we use masked dense convolution to simulate submanifold sparse convolution.
+We also define some sparse pooling or normalization layers in [/pretrain/encoder.py](/pretrain/encoder.py).
All these "sparse" layers are implemented through pytorch built-in operators.
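+
+Below is a minimal, illustrative PyTorch sketch of this masked-dense-convolution idea (the class and variable names here are made up for illustration and are not the actual code in [/pretrain/encoder.py](/pretrain/encoder.py)):
+
+```python
+import torch
+import torch.nn as nn
+
+class MaskedDenseConv2d(nn.Conv2d):
+    """Imitate a submanifold sparse conv: run a dense conv, then re-zero the masked positions."""
+    def forward(self, x, active_b1hw):
+        # active_b1hw: (B, 1, H_out, W_out) mask, 1 = visible position, 0 = masked position
+        x = super().forward(x)    # ordinary dense convolution
+        return x * active_b1hw    # masked positions stay zero, as with a sparse convolution
+
+conv = MaskedDenseConv2d(3, 16, kernel_size=3, padding=1)
+img = torch.rand(2, 3, 224, 224)
+mask = (torch.rand(2, 1, 224, 224) > 0.6).float()   # roughly 60% of positions masked
+out = conv(img, mask)                                # (2, 16, 224, 224), zero wherever mask == 0
+```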
@@ -80,10 +93,8 @@ In SparK, the mask patch size **equals to** the downsample ratio of the CNN mode
Here is the reason: when we apply the mask, we:
-1. first generate the binary mask for the **smallest** resolution feature map, i.e., generate the `_cur_active` or `active_b1ff` in [line86-87](https://github.com/keyu-tian/SparK/blob/main/pretrain/spark.py#L86), which is a `torch.BoolTensor` shaped as `[B, 1, fmap_size, fmap_size]`, and would be used to mask the smallest feature map.
-3. then progressively upsample it (i.e., expand its 2nd and 3rd dimensions by calling `repeat_interleave(..., 2)` and `repeat_interleave(..., 3)` in [line16](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py#L16)), to mask those feature maps ([`x` in line21](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py#L21)) with larger resolutions .
+1. first generate the binary mask for the **smallest** resolution feature map, i.e., generate the `_cur_active` or `active_b1ff` in [/pretrain/spark.py line86-87](/pretrain/spark.py), which is a `torch.BoolTensor` shaped as `[B, 1, fmap_size, fmap_size]`, and will be used to mask the smallest feature map.
+2. then progressively upsample it (i.e., expand its 2nd and 3rd dimensions by calling `repeat_interleave(..., 2)` and `repeat_interleave(..., 3)` in [/pretrain/encoder.py line16](/pretrain/encoder.py)), to mask those feature maps ([`x` in line21](/pretrain/encoder.py)) with larger resolutions.
So if you want a patch size of 16 or 8, you should actually define a new CNN model with a downsample ratio of 16 or 8.
-Note that the `forward` function of this CNN should have an arg named `hierarchy`. You can look at https://github.com/keyu-tian/SparK/blob/main/pretrain/models/convnext.py#L78 to see what `hierarchy` means and how to handle it.
-
-After that, you can simply run `main.sh` with `--hierarchy=3` and see if it works.
+See `Tutorial for customizing your own CNN model` above.
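+
+For reference, here is a small, self-contained sketch of the two masking steps above (simplified: the real code samples an exact fraction of patches; see [/pretrain/spark.py](/pretrain/spark.py) and [/pretrain/encoder.py](/pretrain/encoder.py)):
+
+```python
+import torch
+
+B, input_size, downsample_ratio, mask_ratio = 2, 224, 32, 0.6
+fmap_size = input_size // downsample_ratio   # 7: spatial size of the smallest feature map
+
+# step 1: a binary mask for the smallest feature map, shaped (B, 1, fmap_size, fmap_size)
+active_b1ff = torch.rand(B, 1, fmap_size, fmap_size) > mask_ratio   # roughly 40% of patches stay visible
+
+# step 2: progressively upsample it to mask the feature maps with larger resolutions
+for scale in (1, 2, 4, 8, 16, 32):
+    active_b1hw = active_b1ff.repeat_interleave(scale, 2).repeat_interleave(scale, 3)
+    print(tuple(active_b1hw.shape))   # (2, 1, 7*scale, 7*scale); at scale=32 it covers the 224x224 input
+```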
diff --git a/pretrain/decoder.py b/pretrain/decoder.py
index 61f89b5..0b31df1 100644
--- a/pretrain/decoder.py
+++ b/pretrain/decoder.py
@@ -32,12 +32,12 @@ class UNetBlock(nn.Module):
class LightDecoder(nn.Module):
- def __init__(self, up_sample_ratio, width=768, sbn=True):
+    def __init__(self, up_sample_ratio, width=768, sbn=True): # todo: the decoder's width follows a simple halving rule; you can change it to any other rule
super().__init__()
self.width = width
assert is_pow2n(up_sample_ratio)
n = round(math.log2(up_sample_ratio))
- channels = [self.width // 2 ** i for i in range(n + 1)]
+        channels = [self.width // 2 ** i for i in range(n + 1)] # todo: the decoder's width follows a simple halving rule; you can change it to any other rule
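+        # e.g., width=768 with up_sample_ratio=32 (n=5) gives channels == [768, 384, 192, 96, 48, 24]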
bn2d = nn.SyncBatchNorm if sbn else nn.BatchNorm2d
self.dec = nn.ModuleList([UNetBlock(cin, cout, bn2d) for (cin, cout) in zip(channels[:-1], channels[1:])])
self.proj = nn.Conv2d(channels[-1], 3, kernel_size=1, stride=1, bias=True)
diff --git a/pretrain/dist.py b/pretrain/dist.py
index 0c12d84..bae9379 100644
--- a/pretrain/dist.py
+++ b/pretrain/dist.py
@@ -8,12 +8,11 @@ import os
from typing import List
from typing import Union
+import sys
import torch
import torch.distributed as tdist
import torch.multiprocessing as mp
-from torch.distributed import barrier as __barrier
-barrier = __barrier
__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu'
__initialized = False
@@ -23,6 +22,10 @@ def initialized():
def initialize(backend='nccl'):
+ if not torch.cuda.is_available():
+        print('[dist initialize] cuda is not available, using cpu instead', file=sys.stderr)
+ return
+
# ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn')
@@ -64,6 +67,11 @@ def is_local_master():
return __local_rank == 0
+def barrier():
+ if __initialized:
+ tdist.barrier()
+
+
def parallelize(net, syncbn=False):
if syncbn:
net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
diff --git a/pretrain/encoder.py b/pretrain/encoder.py
index 46dcbe9..7342e31 100644
--- a/pretrain/encoder.py
+++ b/pretrain/encoder.py
@@ -156,10 +156,10 @@ class SparseConvNeXtBlock(nn.Module):
class SparseEncoder(nn.Module):
- def __init__(self, conv_model, input_size, downsample_raito, encoder_fea_dim, sbn=False, verbose=False):
+ def __init__(self, cnn, input_size, sbn=False, verbose=False):
super(SparseEncoder, self).__init__()
- self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=conv_model, verbose=verbose, sbn=sbn)
- self.input_size, self.downsample_raito, self.fea_dim = input_size, downsample_raito, encoder_fea_dim
+ self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=cnn, verbose=verbose, sbn=sbn)
+ self.input_size, self.downsample_raito, self.enc_feat_map_chs = input_size, cnn.get_downsample_ratio(), cnn.get_feature_map_channels()
@staticmethod
def dense_model_to_sparse(m: nn.Module, verbose=False, sbn=False):
@@ -204,5 +204,5 @@ class SparseEncoder(nn.Module):
del m
return oup
- def forward(self, x, hierarchy):
- return self.sp_cnn(x, hierarchy=hierarchy)
+ def forward(self, x):
+ return self.sp_cnn(x, hierarchical=True)
diff --git a/pretrain/main.py b/pretrain/main.py
index 82406e0..6fbcbdc 100644
--- a/pretrain/main.py
+++ b/pretrain/main.py
@@ -37,7 +37,7 @@ def main_pt():
data_loader_train = DataLoader(
dataset=dataset_train, num_workers=args.dataloader_workers, pin_memory=True,
batch_sampler=DistInfiniteBatchSampler(
- dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, seed=args.seed,
+ dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size,
shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(),
), worker_init_fn=worker_init_fn
)
@@ -49,7 +49,7 @@ def main_pt():
dec = LightDecoder(enc.downsample_raito, sbn=args.sbn)
model_without_ddp = SparK(
sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask,
- densify_norm=args.densify_norm, sbn=args.sbn, hierarchy=args.hierarchy,
+ densify_norm=args.densify_norm, sbn=args.sbn,
).to(args.device)
print(f'[PT model] model = {model_without_ddp}\n')
model: DistributedDataParallel = DistributedDataParallel(model_without_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
diff --git a/pretrain/models/__init__.py b/pretrain/models/__init__.py
index 6785a17..68aad6f 100644
--- a/pretrain/models/__init__.py
+++ b/pretrain/models/__init__.py
@@ -30,38 +30,30 @@ for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath):
clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
-model_alias_to_fullname = {
- 'res50': 'resnet50',
- 'res101': 'resnet101',
- 'res152': 'resnet152',
- 'res200': 'resnet200',
- 'cnxS': 'convnext_small',
- 'cnxB': 'convnext_base',
- 'cnxL': 'convnext_large',
+pretrain_default_model_kwargs = {
+ 'your_convnet': dict(),
+ 'resnet50': dict(drop_path_rate=0.05),
+ 'resnet101': dict(drop_path_rate=0.08),
+ 'resnet152': dict(drop_path_rate=0.10),
+ 'resnet200': dict(drop_path_rate=0.15),
+ 'convnext_small': dict(sparse=True, drop_path_rate=0.2),
+ 'convnext_base': dict(sparse=True, drop_path_rate=0.3),
+ 'convnext_large': dict(sparse=True, drop_path_rate=0.4),
}
-model_fullname_to_alias = {v: k for k, v in model_alias_to_fullname.items()}
-
-
-pre_train_d = { # default drop_path_rate, num of para, FLOPs, downsample_ratio, num of channel
- 'resnet50': [dict(drop_path_rate=0.05), 25.6, 4.1, 32, 2048],
- 'resnet101': [dict(drop_path_rate=0.08), 44.5, 7.9, 32, 2048],
- 'resnet152': [dict(drop_path_rate=0.10), 60.2, 11.6, 32, 2048],
- 'resnet200': [dict(drop_path_rate=0.15), 64.7, 15.1, 32, 2048],
- 'convnext_small': [dict(sparse=True, drop_path_rate=0.2), 50.0, 8.7, 32, 768],
- 'convnext_base': [dict(sparse=True, drop_path_rate=0.3), 89.0, 15.4, 32, 1024],
- 'convnext_large': [dict(sparse=True, drop_path_rate=0.4), 198.0, 34.4, 32, 1536],
-}
-for v in pre_train_d.values():
- v[0]['pretrained'] = False
- v[0]['num_classes'] = 0
- v[0]['global_pool'] = ''
+for kw in pretrain_default_model_kwargs.values():
+ kw['pretrained'] = False
+ kw['num_classes'] = 0
+ kw['global_pool'] = ''
def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False):
from encoder import SparseEncoder
- kwargs, params, flops, downsample_raito, fea_dim = pre_train_d[name]
+ kwargs = pretrain_default_model_kwargs[name]
if drop_path_rate != 0:
kwargs['drop_path_rate'] = drop_path_rate
print(f'[build_sparse_encoder] model kwargs={kwargs}')
- return SparseEncoder(create_model(name, **kwargs), input_size=input_size, downsample_raito=downsample_raito, encoder_fea_dim=fea_dim, sbn=sbn, verbose=verbose)
+ cnn = create_model(name, **kwargs)
+
+ return SparseEncoder(cnn, input_size=input_size, sbn=sbn, verbose=verbose)
+
diff --git a/pretrain/models/convnext.py b/pretrain/models/convnext.py
index a205764..6b0169e 100644
--- a/pretrain/models/convnext.py
+++ b/pretrain/models/convnext.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
#
# This file is basically a copy of: https://github.com/facebookresearch/ConvNeXt/blob/06f7b05f922e21914916406141f50f82b4a15852/models/convnext.py
+from typing import List
import torch
import torch.nn as nn
@@ -34,7 +35,7 @@ class ConvNeXt(nn.Module):
sparse=True,
):
super().__init__()
-
+ self.dims: List[int] = dims
self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
@@ -75,13 +76,19 @@ class ConvNeXt(nn.Module):
trunc_normal_(m.weight, std=.02)
nn.init.constant_(m.bias, 0)
- def forward(self, x, hierarchy=0):
- ls = []
- for i in range(4):
- x = self.downsample_layers[i](x)
- x = self.stages[i](x)
- ls.append(x if hierarchy >= 4-i else None)
- if hierarchy:
+ def get_downsample_ratio(self) -> int:
+ return 32
+
+ def get_feature_map_channels(self) -> List[int]:
+ return self.dims
+
+ def forward(self, x, hierarchical=False):
+ if hierarchical:
+ ls = []
+ for i in range(4):
+ x = self.downsample_layers[i](x)
+ x = self.stages[i](x)
+ ls.append(x)
return ls
else:
return self.fc(self.norm(x.mean([-2, -1]))) # (B, C, H, W) =mean=> (B, C) =norm&fc=> (B, NumCls)
@@ -116,17 +123,3 @@ def convnext_large(pretrained=False, in_22k=False, **kwargs):
model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
return model
-
-if __name__ == '__main__':
- from timm.models import create_model
- cnx = create_model('convnext_small', sparse=False)
-
- def prt(lst):
- print([tuple(t.shape) if t is not None else '(None)' for t in lst])
- with torch.no_grad():
- inp = torch.rand(2, 3, 224, 224)
- prt(cnx(inp))
- prt(cnx(inp, hierarchy=1))
- prt(cnx(inp, hierarchy=2))
- prt(cnx(inp, hierarchy=3))
- prt(cnx(inp, hierarchy=4))
diff --git a/pretrain/models/custom.py b/pretrain/models/custom.py
new file mode 100644
index 0000000..3ebbef4
--- /dev/null
+++ b/pretrain/models/custom.py
@@ -0,0 +1,89 @@
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from typing import List
+from timm.models.registry import register_model
+
+
+class YourConvNet(nn.Module):
+ """
+ This is a template for your custom ConvNet.
+ It is required to implement the following three functions: `get_downsample_ratio`, `get_feature_map_channels`, `forward`.
+    You can refer to the implementations in `pretrain/models/resnet.py` for an example.
+ """
+
+ def get_downsample_ratio(self) -> int:
+ """
+ This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).
+
+ :return: the TOTAL downsample ratio of the ConvNet.
+ E.g., for a ResNet-50, this should return 32.
+ """
+ raise NotImplementedError
+
+ def get_feature_map_channels(self) -> List[int]:
+ """
+ This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).
+
+ :return: a list of the number of channels of each feature map.
+ E.g., for a ResNet-50, this should return [256, 512, 1024, 2048].
+ """
+ raise NotImplementedError
+
+ def forward(self, inp_bchw: torch.Tensor, hierarchical=False):
+ """
+ The forward with `hierarchical=True` would ONLY be used in `SparseEncoder.forward` (see `pretrain/encoder.py`).
+
+ :param inp_bchw: input image tensor, shape: (batch_size, channels, height, width).
+ :param hierarchical: return the logits (not hierarchical), or the feature maps (hierarchical).
+ :return:
+ - hierarchical == False: return the logits of the classification task, shape: (batch_size, num_classes).
+ - hierarchical == True: return a list of all feature maps, which should have the same length as the return value of `get_feature_map_channels`.
+ E.g., for a ResNet-50, it should return a list [1st_feat_map, 2nd_feat_map, 3rd_feat_map, 4th_feat_map].
+ for an input size of 224, the shapes are [(B, 256, 56, 56), (B, 512, 28, 28), (B, 1024, 14, 14), (B, 2048, 7, 7)]
+ """
+ raise NotImplementedError
+
+
+@register_model
+def your_convnet_small(pretrained=False, **kwargs):
+ raise NotImplementedError
+ return YourConvNet(**kwargs)
+
+
+@torch.no_grad()
+def convnet_test():
+ from timm.models import create_model
+ cnn = create_model('your_convnet_small')
+ print('get_downsample_ratio:', cnn.get_downsample_ratio())
+ print('get_feature_map_channels:', cnn.get_feature_map_channels())
+
+ downsample_ratio = cnn.get_downsample_ratio()
+ feature_map_channels = cnn.get_feature_map_channels()
+
+ # check the forward function
+ B, C, H, W = 4, 3, 224, 224
+ inp = torch.rand(B, C, H, W)
+ feats = cnn(inp, hierarchical=True)
+ assert isinstance(feats, list)
+ assert len(feats) == len(feature_map_channels)
+ print([tuple(t.shape) for t in feats])
+
+ # check the downsample ratio
+ feats = cnn(inp, hierarchical=True)
+ assert feats[-1].shape[-2] == H // downsample_ratio
+ assert feats[-1].shape[-1] == W // downsample_ratio
+
+ # check the channel number
+ for feat, ch in zip(feats, feature_map_channels):
+ assert feat.ndim == 4
+ assert feat.shape[1] == ch
+
+
+if __name__ == '__main__':
+ convnet_test()
diff --git a/pretrain/models/resnet.py b/pretrain/models/resnet.py
index dcc1ce3..62f76eb 100644
--- a/pretrain/models/resnet.py
+++ b/pretrain/models/resnet.py
@@ -3,15 +3,26 @@
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
-
-import math
+from typing import List
import torch
import torch.nn.functional as F
from timm.models.resnet import ResNet
-def forward(self, x, hierarchy=0): # hierarchy: 0 or 1 or 2 or 3 or 4
+# hack: inject the `get_downsample_ratio` function into `timm.models.resnet.ResNet`
+def get_downsample_ratio(self: ResNet) -> int:
+ return 32
+
+
+# hack: inject the `get_feature_map_channels` function into `timm.models.resnet.ResNet`
+def get_feature_map_channels(self: ResNet) -> List[int]:
+ # `self.feature_info` is maintained by `timm`
+ return [info['num_chs'] for info in self.feature_info[1:]]
+
+
+# hack: override the forward function of `timm.models.resnet.ResNet`
+def forward(self, x, hierarchical=False):
""" this forward function is a modified version of `timm.models.resnet.ResNet.forward`
>>> ResNet.forward
"""
@@ -20,17 +31,12 @@ def forward(self, x, hierarchy=0): # hierarchy: 0 or 1 or 2 or 3 or 4
x = self.act1(x)
x = self.maxpool(x)
- ls = []
- x = self.layer1(x)
- ls.append(x if hierarchy >= 4 else None)
- x = self.layer2(x)
- ls.append(x if hierarchy >= 3 else None)
- x = self.layer3(x)
- ls.append(x if hierarchy >= 2 else None)
- x = self.layer4(x)
- ls.append(x if hierarchy >= 1 else None)
-
- if hierarchy:
+ if hierarchical:
+ ls = []
+ x = self.layer1(x); ls.append(x)
+ x = self.layer2(x); ls.append(x)
+ x = self.layer3(x); ls.append(x)
+ x = self.layer4(x); ls.append(x)
return ls
else:
x = self.global_pool(x)
@@ -40,19 +46,39 @@ def forward(self, x, hierarchy=0): # hierarchy: 0 or 1 or 2 or 3 or 4
return x
+ResNet.get_downsample_ratio = get_downsample_ratio
+ResNet.get_feature_map_channels = get_feature_map_channels
ResNet.forward = forward
-if __name__ == '__main__':
+@torch.no_grad()
+def convnet_test():
from timm.models import create_model
- r50 = create_model('resnet50')
+ cnn = create_model('resnet50')
+ print('get_downsample_ratio:', cnn.get_downsample_ratio())
+ print('get_feature_map_channels:', cnn.get_feature_map_channels())
- def prt(lst):
- print([tuple(t.shape) if t is not None else '(None)' for t in lst])
- with torch.no_grad():
- inp = torch.rand(2, 3, 224, 224)
- prt(r50(inp))
- prt(r50(inp, hierarchy=1))
- prt(r50(inp, hierarchy=2))
- prt(r50(inp, hierarchy=3))
- prt(r50(inp, hierarchy=4))
+ downsample_ratio = cnn.get_downsample_ratio()
+ feature_map_channels = cnn.get_feature_map_channels()
+
+ # check the forward function
+ B, C, H, W = 4, 3, 224, 224
+ inp = torch.rand(B, C, H, W)
+ feats = cnn(inp, hierarchical=True)
+ assert isinstance(feats, list)
+ assert len(feats) == len(feature_map_channels)
+ print([tuple(t.shape) for t in feats])
+
+ # check the downsample ratio
+ feats = cnn(inp, hierarchical=True)
+ assert feats[-1].shape[-2] == H // downsample_ratio
+ assert feats[-1].shape[-1] == W // downsample_ratio
+
+ # check the channel number
+ for feat, ch in zip(feats, feature_map_channels):
+ assert feat.ndim == 4
+ assert feat.shape[1] == ch
+
+
+if __name__ == '__main__':
+ convnet_test()
diff --git a/pretrain/sampler.py b/pretrain/sampler.py
index ce8a80a..3140a59 100644
--- a/pretrain/sampler.py
+++ b/pretrain/sampler.py
@@ -19,7 +19,7 @@ def worker_init_fn(worker_id):
class DistInfiniteBatchSampler(Sampler):
- def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=0, filling=False, shuffle=True):
+ def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=1, filling=False, shuffle=True):
assert glb_batch_size % world_size == 0
self.world_size, self.rank = world_size, rank
self.dataset_len = dataset_len
diff --git a/pretrain/spark.py b/pretrain/spark.py
index 9308fd6..c5cfa64 100644
--- a/pretrain/spark.py
+++ b/pretrain/spark.py
@@ -20,7 +20,7 @@ from decoder import LightDecoder
class SparK(nn.Module):
def __init__(
self, sparse_encoder: encoder.SparseEncoder, dense_decoder: LightDecoder,
- mask_ratio=0.6, densify_norm='bn', sbn=False, hierarchy=4,
+ mask_ratio=0.6, densify_norm='bn', sbn=False,
):
super().__init__()
input_size, downsample_raito = sparse_encoder.input_size, sparse_encoder.downsample_raito
@@ -33,15 +33,23 @@ class SparK(nn.Module):
self.dense_decoder = dense_decoder
self.sbn = sbn
- self.hierarchy = hierarchy
+ self.hierarchy = len(sparse_encoder.enc_feat_map_chs)
self.densify_norm_str = densify_norm.lower()
self.densify_norms = nn.ModuleList()
self.densify_projs = nn.ModuleList()
self.mask_tokens = nn.ParameterList()
# build the `densify` layers
- e_width, d_width = self.sparse_encoder.fea_dim, self.dense_decoder.width
- for i in range(self.hierarchy):
+ e_widths, d_width = self.sparse_encoder.enc_feat_map_chs, self.dense_decoder.width
+ e_widths: List[int]
+ for i in range(self.hierarchy): # from the smallest feat map to the largest; i=0: the last feat map; i=1: the second last feat map ...
+ e_width = e_widths.pop()
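+            # e.g., for a ResNet-50, enc_feat_map_chs == [256, 512, 1024, 2048]: i=0 pops 2048, then 1024, 512, 256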
+ # create mask token
+ p = nn.Parameter(torch.zeros(1, e_width, 1, 1))
+ trunc_normal_(p, mean=0, std=.02, a=-.02, b=.02)
+ self.mask_tokens.append(p)
+
+ # create densify norm
if self.densify_norm_str == 'bn':
densify_norm = (encoder.SparseSyncBatchNorm2d if self.sbn else encoder.SparseBatchNorm2d)(e_width)
elif self.densify_norm_str == 'ln':
@@ -50,6 +58,7 @@ class SparK(nn.Module):
densify_norm = nn.Identity()
self.densify_norms.append(densify_norm)
+ # create densify proj
if i == 0 and e_width == d_width:
densify_proj = nn.Identity() # todo: NOTE THAT CONVNEXT-S WOULD USE THIS, because it has a width of 768 that equals to the decoder's width 768
print(f'[SparK.__init__, densify {i+1}/{self.hierarchy}]: use nn.Identity() as densify_proj')
@@ -59,10 +68,7 @@ class SparK(nn.Module):
print(f'[SparK.__init__, densify {i+1}/{self.hierarchy}]: densify_proj(ksz={kernel_size}, #para={sum(x.numel() for x in densify_proj.parameters()) / 1e6:.2f}M)')
self.densify_projs.append(densify_proj)
- p = nn.Parameter(torch.zeros(1, e_width, 1, 1))
- trunc_normal_(p, mean=0, std=.02, a=-.02, b=.02)
- self.mask_tokens.append(p)
- e_width //= 2
+        # todo: the decoder's width follows a simple halving rule; you can change it to any other rule
d_width //= 2
print(f'[SparK.__init__] dims of mask_tokens={tuple(p.numel() for p in self.mask_tokens)}')
@@ -89,7 +95,7 @@ class SparK(nn.Module):
masked_bchw = inp_bchw * active_b1hw
# step2. Encode: get hierarchical encoded sparse features (a list containing 4 feature maps at 4 scales)
- fea_bcffs: List[torch.Tensor] = self.sparse_encoder(masked_bchw, hierarchy=self.hierarchy)
+ fea_bcffs: List[torch.Tensor] = self.sparse_encoder(masked_bchw)
fea_bcffs.reverse() # after reversion: from the smallest feature map to the largest
# step3. Densify: get hierarchical dense features for decoding
diff --git a/pretrain/utils/arg_util.py b/pretrain/utils/arg_util.py
index 1a8cd63..46a9865 100644
--- a/pretrain/utils/arg_util.py
+++ b/pretrain/utils/arg_util.py
@@ -15,19 +15,16 @@ import dist
class Args(Tap):
# environment
- exp_name: str
- exp_dir: str
- data_path: str
+ exp_name: str = 'your_exp_name'
+    exp_dir: str = 'your_exp_dir' # will be created if it does not exist
+ data_path: str = 'imagenet_data_path'
resume_from: str = '' # resume from some checkpoint.pth
- seed: int = 1
# SparK hyperparameters
- mask: float = 0.6
- hierarchy: int = 4
+ mask: float = 0.6 # mask ratio, should be in (0, 1)
# encoder hyperparameters
- model: str = 'res50'
- model_alias: str = 'res50'
+ model: str = 'resnet50'
input_size: int = 224
sbn: bool = True
@@ -70,7 +67,7 @@ class Args(Tap):
@property
def is_resnet(self):
- return 'resnet' in self.model or 'res' in self.model_alias
+ return 'resnet' in self.model
def log_epoch(self):
if not dist.is_local_master():
@@ -96,7 +93,6 @@ class Args(Tap):
def init_dist_and_get_args():
from utils import misc
- from models import model_alias_to_fullname, model_fullname_to_alias
# initialize
args = Args(explicit_bool=True).parse_args()
@@ -116,10 +112,6 @@ def init_dist_and_get_args():
misc.init_distributed_environ(exp_dir=args.exp_dir)
# update args
- if args.model in model_alias_to_fullname.keys():
- args.model = model_alias_to_fullname[args.model]
- args.model_alias = model_fullname_to_alias[args.model]
-
args.first_logging = True
args.device = dist.get_device()
args.batch_size_per_gpu = args.bs // dist.get_world_size()
@@ -137,7 +129,4 @@ def init_dist_and_get_args():
args.lr = args.base_lr * args.glb_batch_size / 256
args.wde = args.wde or args.wd
- if args.hierarchy < 1:
- args.hierarchy = 1
-
return args