[upd] 1. refactor to simplify the pretraining code; 2. add a tutorial for customizing your own CNN model; 3. update some READMEs

main
keyu-tian 2 years ago
parent 46f9ad2871
commit 6ffe453fa5
1. README.md (14 changed lines)
2. downstream_mmdet/README.md (4 changed lines)
3. downstream_mmdet/configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py (8 changed lines)
4. pretrain/README.md (31 changed lines)
5. pretrain/decoder.py (4 changed lines)
6. pretrain/dist.py (12 changed lines)
7. pretrain/encoder.py (10 changed lines)
8. pretrain/main.py (4 changed lines)
9. pretrain/models/__init__.py (44 changed lines)
10. pretrain/models/convnext.py (29 changed lines)
11. pretrain/models/custom.py (89 changed lines)
12. pretrain/models/resnet.py (76 changed lines)
13. pretrain/sampler.py (2 changed lines)
14. pretrain/spark.py (24 changed lines)
15. pretrain/utils/arg_util.py (23 changed lines)

@@ -25,7 +25,7 @@ https://user-images.githubusercontent.com/39692511/226858919-dd4ccf7e-a5ba-4a33-
 ## 🔥 News
-- On **Mar. 22nd (UTC+8 8pm; UTC+0 12am)** another livestream would be held at [极市平台-bilibili](https://live.bilibili.com/3344545).
+- On **Mar. 22nd (UTC+8 8pm)** another livestream would be held at 极市平台-Bilibili! [[`📹Recorded Video`](https://www.bilibili.com/video/BV1Da4y1T7mr/)]
 - The share on [TechBeat (将门创投)](https://www.techbeat.net/talk-info?id=758) is scheduled on **Mar. 16th (UTC+8 8pm)** too! [[`📹Recorded Video`](https://www.techbeat.net/talk-info?id=758)]
 - We are honored to be invited by Synced ("机器之心机动组 视频号" on WeChat) to give a talk about SparK on **Feb. 27th (UTC+0 11am, UTC+8 7pm)**, welcome! [[`📹Recorded Video`](https://www.bilibili.com/video/BV1J54y1u7U3/)]
 - This work got accepted to ICLR 2023 as a Spotlight (notable-top-25%).
@@ -44,7 +44,7 @@ https://user-images.githubusercontent.com/39692511/226858919-dd4ccf7e-a5ba-4a33-
 <!-- ## 📺 Video demo (we use [these ppt slides](https://github.com/keyu-tian/SparK/releases/tag/file_sharing) to make the animated video) -->
 <!-- https://user-images.githubusercontent.com/6366788/213662770-5f814de0-cbe8-48d9-8235-e8907fd81e0e.mp4 -->
-## 🕹 CoLab Visualization Demo
+## 🕹 Colab Visualization Demo
 Check [pretrain/viz_reconstruction.ipynb](pretrain/viz_reconstruction.ipynb) for visualizing the reconstruction of SparK pretrained models, like:
@@ -94,9 +94,10 @@ We also provide [pretrain/viz_spconv.ipynb](pretrain/viz_spconv.ipynb) that show
 <summary>catalog</summary>
 - [x] Pretraining code
+- [x] Pretraining tutorial for custom CNN model ([pretrain/models/custom.py](pretrain/models/custom.py))
+- [x] Pretraining Colab visualization playground ([reconstruction](/pretrain/viz_reconstruction.ipynb), [sparse conv](pretrain/viz_spconv.ipynb))
 - [x] Finetuning code
-- [x] Colab visualization playground ([reconstruction](/pretrain/viz_reconstruction.ipynb), [sparse conv](pretrain/viz_spconv.ipynb))
-- [ ] Weights & visualization playground on `Huggingface`
+- [ ] Weights & visualization playground in `huggingface`
 - [ ] Weights in `timm`
 </details>
@@ -128,7 +129,10 @@ We also provide [pretrain/viz_spconv.ipynb](pretrain/viz_spconv.ipynb) that show
 We highly recommend you use `torch==1.10.0`, `torchvision==0.11.1`, and `timm==0.5.4` for reproduction.
 Check [INSTALL.md](INSTALL.md) to install all pip dependencies.
-- **Pretraining** all models on ImageNet-1k: &nbsp;see [pretrain/](pretrain)
+- **Pretraining**
+  - all ResNets and ConvNeXts on ImageNet-1k: &nbsp;see [pretrain/](pretrain)
+  - **your own CNN models**: &nbsp;see [pretrain/](pretrain), especially [pretrain/models/custom.py](pretrain/models/custom.py)
 - **Finetuning**
   - all models on ImageNet: &nbsp;check [downstream_imagenet/](downstream_imagenet) for subsequent instructions.

@@ -19,9 +19,9 @@ This `downstream_mmdet` is isolated from pre-training codes. One can treat this
 <p>
-## Installation [MMDetection with commit 6a979e2](https://github.com/facebookresearch/detectron2/releases/tag/v0.6) before fine-tuning ConvNeXt on COCO
-We refer to the codebases of [ConvNeXt](https://github.com/facebookresearch/ConvNeXt/tree/048efcea897d999aed302f2639b6270aedf8d4c8) and [Swin-Transformer](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d).
+## Installation [MMDetection with commit 6a979e2](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d) before fine-tuning ConvNeXt on COCO
+We refer to the codebases of [ConvNeXt](https://github.com/facebookresearch/ConvNeXt/tree/048efcea897d999aed302f2639b6270aedf8d4c8) and [Swin-Transformer-Object-Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d).
 Please refer to [README.md](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/6a979e2164e3fb0de0ca2546545013a4d71b2f7d/README.md) for installation and dataset preparation instructions.
 Note the COCO dataset folder should be at `downstream_mmdet/data/coco`.

@@ -23,13 +23,13 @@ _base_ = [
 model = dict(
     backbone=dict(
         in_chans=3,
-        depths=[3, 3, 27, 3],  # [modified] according to tiny to base
-        dims=[128, 256, 512, 1024],  # [modified] according to tiny to base
-        drop_path_rate=0.5,  # [modified] according to tiny to base
+        depths=[3, 3, 27, 3],  # [modified] according to tiny-to-base
+        dims=[128, 256, 512, 1024],  # [modified] according to tiny-to-base
+        drop_path_rate=0.5,  # [modified] according to tiny-to-base
         layer_scale_init_value=1.0,
         out_indices=[0, 1, 2, 3],
     ),
-    neck=dict(in_channels=[128, 256, 512, 1024]))  # [modified] according to tiny to base
+    neck=dict(in_channels=[128, 256, 512, 1024]))  # [modified] according to tiny-to-base
 img_norm_cfg = dict(
     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

@@ -1,17 +1,30 @@
 ## Preparation for ImageNet-1k pre-training
-See [INSTALL.md](https://github.com/keyu-tian/SparK/blob/main/INSTALL.md) to prepare `pip` dependencies and the ImageNet dataset.
+See [/INSTALL.md](/INSTALL.md) to prepare `pip` dependencies and the ImageNet dataset.
 **Note: for network definitions, we directly use `timm.models.ResNet` and [official ConvNeXt](https://github.com/facebookresearch/ConvNeXt/blob/048efcea897d999aed302f2639b6270aedf8d4c8/models/convnext.py).**
+## Tutorial for customizing your own CNN model
+See [/pretrain/models/custom.py](/pretrain/models/custom.py). What you need to do is:
+- implement the member function `get_downsample_ratio` in [/pretrain/models/custom.py](/pretrain/models/custom.py).
+- implement the member function `get_feature_map_channels` in [/pretrain/models/custom.py](/pretrain/models/custom.py).
+- implement the member function `forward` in [/pretrain/models/custom.py](/pretrain/models/custom.py).
+- define `your_convnet(...)` with `@register_model` in [/pretrain/models/custom.py](/pretrain/models/custom.py).
+- add the default kwargs of `your_convnet(...)` in [/pretrain/models/__init__.py](/pretrain/models/__init__.py).
+Then you can use `--model=your_convnet` in the pre-training script.
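*Editor's note:* for illustration, here is a minimal sketch of what such a custom model could look like once the checklist above is filled in. The class name `TinyConvNet`, its toy architecture, and the downsample ratio of 16 are assumptions made for this example; only the three member functions and the `@register_model` entry point come from the tutorial itself.

```python
from typing import List
import torch
import torch.nn as nn
from timm.models.registry import register_model


class TinyConvNet(nn.Module):
    """A toy 4-stage CNN implementing the interface that SparK's SparseEncoder expects (illustrative only)."""
    def __init__(self, dims=(64, 128, 256, 512), num_classes=0, **timm_kwargs):  # extra timm kwargs (e.g. global_pool) are ignored here
        super().__init__()
        self.dims = list(dims)
        chs = [3] + self.dims
        # four stride-2 stages: 224 -> 112 -> 56 -> 28 -> 14, i.e. a total downsample ratio of 16
        self.stages = nn.ModuleList(
            nn.Sequential(nn.Conv2d(cin, cout, 3, stride=2, padding=1), nn.BatchNorm2d(cout), nn.ReLU(inplace=True))
            for cin, cout in zip(chs[:-1], chs[1:])
        )
        self.head = nn.Linear(self.dims[-1], num_classes) if num_classes > 0 else nn.Identity()

    def get_downsample_ratio(self) -> int:
        return 16  # so the mask patch size would be 16x16

    def get_feature_map_channels(self) -> List[int]:
        return self.dims  # one entry per feature map returned by forward(..., hierarchical=True)

    def forward(self, x, hierarchical=False):
        feats = []
        for stage in self.stages:
            x = stage(x)
            feats.append(x)
        if hierarchical:
            return feats                      # a list of 4 feature maps, channels == self.dims
        return self.head(x.mean(dim=(2, 3)))  # global-average-pool, then classify


@register_model
def your_convnet(pretrained=False, **kwargs):
    return TinyConvNet(**kwargs)
```

Once registered this way, `--model=your_convnet` would be resolved through `timm.models.create_model`, which is how `build_sparse_encoder` instantiates the backbone.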
 ## Pre-training Any Model on ImageNet-1k (224x224)
-For pre-training, run [main.sh](https://github.com/keyu-tian/SparK/blob/main/pretrain/main.sh) with bash.
+For pre-training, run [/pretrain/main.sh](/pretrain/main.sh) with bash.
 It is **required** to specify the ImageNet data folder (`--data_path`), the model name (`--model`), and your experiment name (the first argument of `main.sh`) when running the script.
 We use the **same** pre-training configurations (lr, batch size, etc.) for all models (ResNets and ConvNeXts).
-Their names and **default values** can be found in [utils/arg_util.py line24-47](https://github.com/keyu-tian/SparK/blob/main/pretrain/utils/arg_util.py#L24).
+Their names and **default values** can be found in [/pretrain/utils/arg_util.py line24-47](/pretrain/utils/arg_util.py).
 These default configurations (like batch size 4096) will be used unless you override them, e.g. `--bs=512`.
 Here is an example command pre-training a ResNet50 on a single machine with 8 GPUs:
@@ -69,8 +82,8 @@ Add `--resume_from=path/to/<model>still_pretraining.pth` to resume from a saved
 ## Regarding sparse convolution
 We do not use sparse convolutions in this PyTorch implementation, due to their limited optimization on modern hardware.
-As can be found in [encoder.py](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py), we use masked dense convolution to simulate submanifold sparse convolution.
-We also define some sparse pooling or normalization layers in [encoder.py](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py).
+As can be found in [/pretrain/encoder.py](/pretrain/encoder.py), we use masked dense convolution to simulate submanifold sparse convolution.
+We also define some sparse pooling or normalization layers in [/pretrain/encoder.py](/pretrain/encoder.py).
 All these "sparse" layers are implemented through PyTorch built-in operators.
@@ -80,10 +93,8 @@ In SparK, the mask patch size **equals to** the downsample ratio of the CNN mode
 Here is the reason: when we mask, we:
-1. first generate the binary mask for the **smallest** resolution feature map, i.e., generate the `_cur_active` or `active_b1ff` in [line86-87](https://github.com/keyu-tian/SparK/blob/main/pretrain/spark.py#L86), which is a `torch.BoolTensor` shaped as `[B, 1, fmap_size, fmap_size]`, and would be used to mask the smallest feature map.
-2. then progressively upsample it (i.e., expand its 2nd and 3rd dimensions by calling `repeat_interleave(..., 2)` and `repeat_interleave(..., 3)` in [line16](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py#L16)), to mask those feature maps ([`x` in line21](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py#L21)) with larger resolutions.
+1. first generate the binary mask for the **smallest** resolution feature map, i.e., generate the `_cur_active` or `active_b1ff` in [/pretrain/spark.py line86-87](/pretrain/spark.py), which is a `torch.BoolTensor` shaped as `[B, 1, fmap_size, fmap_size]`, and would be used to mask the smallest feature map.
+2. then progressively upsample it (i.e., expand its 2nd and 3rd dimensions by calling `repeat_interleave(..., 2)` and `repeat_interleave(..., 3)` in [/pretrain/encoder.py line16](/pretrain/encoder.py)), to mask those feature maps ([`x` in line21](/pretrain/encoder.py)) with larger resolutions.
 So if you want a patch size of 16 or 8, you should actually define a new CNN model with a downsample ratio of 16 or 8.
-Note that the `forward` function of this CNN should have an arg named `hierarchy`. You can look at https://github.com/keyu-tian/SparK/blob/main/pretrain/models/convnext.py#L78 to see what `hierarchy` means and how to handle it.
-After that, you can simply run `main.sh` with `--hierarchy=3` and see if it works.
+See `Tutorial for customizing your own CNN model` above.
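*Editor's note:* a simplified sketch of these two steps follows; the exact mask-generation code in `spark.py` differs, this only illustrates the shapes and the `repeat_interleave` upsampling described above.

```python
import torch

B, fmap_size, mask_ratio = 2, 7, 0.6        # e.g. a 224x224 input with a downsample ratio of 32
len_keep = round(fmap_size * fmap_size * (1 - mask_ratio))

# step 1: a binary mask for the smallest feature map (True = kept)
noise = torch.rand(B, fmap_size * fmap_size)
rank = noise.argsort(dim=1).argsort(dim=1)                          # per-position rank of the noise
active_b1ff = (rank < len_keep).view(B, 1, fmap_size, fmap_size)    # torch.BoolTensor, [B, 1, 7, 7]

# step 2: progressively upsample it to mask larger feature maps (and the input image)
def upsample(active_b1ff: torch.Tensor, to_size: int) -> torch.Tensor:
    r = to_size // active_b1ff.shape[-1]
    return active_b1ff.repeat_interleave(r, dim=2).repeat_interleave(r, dim=3)

mask_14, mask_56, mask_224 = upsample(active_b1ff, 14), upsample(active_b1ff, 56), upsample(active_b1ff, 224)
print(mask_224.shape)   # torch.Size([2, 1, 224, 224]): each kept/masked cell becomes a 32x32 patch
```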

@@ -32,12 +32,12 @@ class UNetBlock(nn.Module):
 class LightDecoder(nn.Module):
-    def __init__(self, up_sample_ratio, width=768, sbn=True):
+    def __init__(self, up_sample_ratio, width=768, sbn=True):  # todo: the decoder's width follows a simple halving rule; you can change it to any other rule
         super().__init__()
         self.width = width
         assert is_pow2n(up_sample_ratio)
         n = round(math.log2(up_sample_ratio))
-        channels = [self.width // 2 ** i for i in range(n + 1)]
+        channels = [self.width // 2 ** i for i in range(n + 1)]  # todo: the decoder's width follows a simple halving rule; you can change it to any other rule
         bn2d = nn.SyncBatchNorm if sbn else nn.BatchNorm2d
         self.dec = nn.ModuleList([UNetBlock(cin, cout, bn2d) for (cin, cout) in zip(channels[:-1], channels[1:])])
         self.proj = nn.Conv2d(channels[-1], 3, kernel_size=1, stride=1, bias=True)
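*Editor's note:* for concreteness, the halving rule above gives the following block widths for the default `width=768` and an encoder downsample ratio of 32 (a small illustrative computation, not code from the commit):

```python
import math

width, up_sample_ratio = 768, 32                   # decoder width; encoder downsample ratio
n = round(math.log2(up_sample_ratio))              # 5 UNet-style upsampling blocks
channels = [width // 2 ** i for i in range(n + 1)]
print(channels)                                    # [768, 384, 192, 96, 48, 24]
# the final 1x1 `proj` conv then maps the last 24 channels back to 3 (RGB)
```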

@@ -8,12 +8,11 @@ import os
 from typing import List
 from typing import Union
+import sys
 import torch
 import torch.distributed as tdist
 import torch.multiprocessing as mp
-from torch.distributed import barrier as __barrier
-barrier = __barrier
 __rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu'
 __initialized = False
@@ -23,6 +22,10 @@ def initialized():
 def initialize(backend='nccl'):
+    if not torch.cuda.is_available():
+        print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
+        return
     # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
     if mp.get_start_method(allow_none=True) is None:
         mp.set_start_method('spawn')
@@ -64,6 +67,11 @@ def is_local_master():
     return __local_rank == 0
+def barrier():
+    if __initialized:
+        tdist.barrier()
 def parallelize(net, syncbn=False):
     if syncbn:
         net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)

@@ -156,10 +156,10 @@ class SparseConvNeXtBlock(nn.Module):
 class SparseEncoder(nn.Module):
-    def __init__(self, conv_model, input_size, downsample_raito, encoder_fea_dim, sbn=False, verbose=False):
+    def __init__(self, cnn, input_size, sbn=False, verbose=False):
         super(SparseEncoder, self).__init__()
-        self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=conv_model, verbose=verbose, sbn=sbn)
-        self.input_size, self.downsample_raito, self.fea_dim = input_size, downsample_raito, encoder_fea_dim
+        self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=cnn, verbose=verbose, sbn=sbn)
+        self.input_size, self.downsample_raito, self.enc_feat_map_chs = input_size, cnn.get_downsample_ratio(), cnn.get_feature_map_channels()
     @staticmethod
     def dense_model_to_sparse(m: nn.Module, verbose=False, sbn=False):
@@ -204,5 +204,5 @@ class SparseEncoder(nn.Module):
         del m
         return oup
-    def forward(self, x, hierarchy):
-        return self.sp_cnn(x, hierarchy=hierarchy)
+    def forward(self, x):
+        return self.sp_cnn(x, hierarchical=True)

@@ -37,7 +37,7 @@ def main_pt():
     data_loader_train = DataLoader(
         dataset=dataset_train, num_workers=args.dataloader_workers, pin_memory=True,
         batch_sampler=DistInfiniteBatchSampler(
-            dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, seed=args.seed,
+            dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size,
             shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(),
         ), worker_init_fn=worker_init_fn
     )
@@ -49,7 +49,7 @@ def main_pt():
     dec = LightDecoder(enc.downsample_raito, sbn=args.sbn)
     model_without_ddp = SparK(
         sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask,
-        densify_norm=args.densify_norm, sbn=args.sbn, hierarchy=args.hierarchy,
+        densify_norm=args.densify_norm, sbn=args.sbn,
     ).to(args.device)
     print(f'[PT model] model = {model_without_ddp}\n')
     model: DistributedDataParallel = DistributedDataParallel(model_without_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)

@@ -30,38 +30,30 @@ for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath):
     clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
-model_alias_to_fullname = {
-    'res50': 'resnet50',
-    'res101': 'resnet101',
-    'res152': 'resnet152',
-    'res200': 'resnet200',
-    'cnxS': 'convnext_small',
-    'cnxB': 'convnext_base',
-    'cnxL': 'convnext_large',
+pretrain_default_model_kwargs = {
+    'your_convnet': dict(),
+    'resnet50': dict(drop_path_rate=0.05),
+    'resnet101': dict(drop_path_rate=0.08),
+    'resnet152': dict(drop_path_rate=0.10),
+    'resnet200': dict(drop_path_rate=0.15),
+    'convnext_small': dict(sparse=True, drop_path_rate=0.2),
+    'convnext_base': dict(sparse=True, drop_path_rate=0.3),
+    'convnext_large': dict(sparse=True, drop_path_rate=0.4),
 }
-model_fullname_to_alias = {v: k for k, v in model_alias_to_fullname.items()}
-pre_train_d = {  # default drop_path_rate, num of para, FLOPs, downsample_ratio, num of channel
-    'resnet50': [dict(drop_path_rate=0.05), 25.6, 4.1, 32, 2048],
-    'resnet101': [dict(drop_path_rate=0.08), 44.5, 7.9, 32, 2048],
-    'resnet152': [dict(drop_path_rate=0.10), 60.2, 11.6, 32, 2048],
-    'resnet200': [dict(drop_path_rate=0.15), 64.7, 15.1, 32, 2048],
-    'convnext_small': [dict(sparse=True, drop_path_rate=0.2), 50.0, 8.7, 32, 768],
-    'convnext_base': [dict(sparse=True, drop_path_rate=0.3), 89.0, 15.4, 32, 1024],
-    'convnext_large': [dict(sparse=True, drop_path_rate=0.4), 198.0, 34.4, 32, 1536],
-}
-for v in pre_train_d.values():
-    v[0]['pretrained'] = False
-    v[0]['num_classes'] = 0
-    v[0]['global_pool'] = ''
+for kw in pretrain_default_model_kwargs.values():
+    kw['pretrained'] = False
+    kw['num_classes'] = 0
+    kw['global_pool'] = ''
 def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False):
     from encoder import SparseEncoder
-    kwargs, params, flops, downsample_raito, fea_dim = pre_train_d[name]
+    kwargs = pretrain_default_model_kwargs[name]
     if drop_path_rate != 0:
         kwargs['drop_path_rate'] = drop_path_rate
     print(f'[build_sparse_encoder] model kwargs={kwargs}')
-    return SparseEncoder(create_model(name, **kwargs), input_size=input_size, downsample_raito=downsample_raito, encoder_fea_dim=fea_dim, sbn=sbn, verbose=verbose)
+    cnn = create_model(name, **kwargs)
+    return SparseEncoder(cnn, input_size=input_size, sbn=sbn, verbose=verbose)
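*Editor's note:* a hypothetical usage sketch of the refactored builder. The import paths are assumptions (it would be run from within `pretrain/`); the attribute names come from the new `SparseEncoder.__init__` above, including the repo's `downsample_raito` spelling.

```python
from models import build_sparse_encoder    # pretrain/models/__init__.py
from decoder import LightDecoder           # pretrain/decoder.py

enc = build_sparse_encoder('resnet50', input_size=224, sbn=False)
print(enc.downsample_raito)     # 32
print(enc.enc_feat_map_chs)     # [256, 512, 1024, 2048] for a ResNet-50
dec = LightDecoder(enc.downsample_raito, sbn=False)   # the decoder upsamples by the same ratio
```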

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 #
 # This file is basically a copy of: https://github.com/facebookresearch/ConvNeXt/blob/06f7b05f922e21914916406141f50f82b4a15852/models/convnext.py
+from typing import List
 import torch
 import torch.nn as nn
@@ -34,7 +35,7 @@ class ConvNeXt(nn.Module):
         sparse=True,
     ):
         super().__init__()
+        self.dims: List[int] = dims
         self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
         stem = nn.Sequential(
             nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
@@ -75,13 +76,19 @@ class ConvNeXt(nn.Module):
             trunc_normal_(m.weight, std=.02)
             nn.init.constant_(m.bias, 0)
-    def forward(self, x, hierarchy=0):
-        ls = []
-        for i in range(4):
-            x = self.downsample_layers[i](x)
-            x = self.stages[i](x)
-            ls.append(x if hierarchy >= 4-i else None)
-        if hierarchy:
-            return ls
-        else:
-            return self.fc(self.norm(x.mean([-2, -1])))  # (B, C, H, W) =mean=> (B, C) =norm&fc=> (B, NumCls)
+    def get_downsample_ratio(self) -> int:
+        return 32
+    def get_feature_map_channels(self) -> List[int]:
+        return self.dims
+    def forward(self, x, hierarchical=False):
+        if hierarchical:
+            ls = []
+            for i in range(4):
+                x = self.downsample_layers[i](x)
+                x = self.stages[i](x)
+                ls.append(x)
+            return ls
+        else:
+            return self.fc(self.norm(x.mean([-2, -1])))  # (B, C, H, W) =mean=> (B, C) =norm&fc=> (B, NumCls)
@@ -116,17 +123,3 @@ def convnext_large(pretrained=False, in_22k=False, **kwargs):
     model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
     return model
-if __name__ == '__main__':
-    from timm.models import create_model
-    cnx = create_model('convnext_small', sparse=False)
-    def prt(lst):
-        print([tuple(t.shape) if t is not None else '(None)' for t in lst])
-    with torch.no_grad():
-        inp = torch.rand(2, 3, 224, 224)
-        prt(cnx(inp))
-        prt(cnx(inp, hierarchy=1))
-        prt(cnx(inp, hierarchy=2))
-        prt(cnx(inp, hierarchy=3))
-        prt(cnx(inp, hierarchy=4))

@@ -0,0 +1,89 @@
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+import torch.nn as nn
+from typing import List
+from timm.models.registry import register_model
+
+
+class YourConvNet(nn.Module):
+    """
+    This is a template for your custom ConvNet.
+    It is required to implement the following three functions: `get_downsample_ratio`, `get_feature_map_channels`, `forward`.
+    You can refer to the implementations in `pretrain/models/resnet.py` for an example.
+    """
+    
+    def get_downsample_ratio(self) -> int:
+        """
+        This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).
+        
+        :return: the TOTAL downsample ratio of the ConvNet.
+        E.g., for a ResNet-50, this should return 32.
+        """
+        raise NotImplementedError
+    
+    def get_feature_map_channels(self) -> List[int]:
+        """
+        This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).
+        
+        :return: a list of the number of channels of each feature map.
+        E.g., for a ResNet-50, this should return [256, 512, 1024, 2048].
+        """
+        raise NotImplementedError
+    
+    def forward(self, inp_bchw: torch.Tensor, hierarchical=False):
+        """
+        The forward with `hierarchical=True` would ONLY be used in `SparseEncoder.forward` (see `pretrain/encoder.py`).
+        
+        :param inp_bchw: input image tensor, shape: (batch_size, channels, height, width).
+        :param hierarchical: return the logits (not hierarchical), or the feature maps (hierarchical).
+        :return:
+            - hierarchical == False: return the logits of the classification task, shape: (batch_size, num_classes).
+            - hierarchical == True: return a list of all feature maps, which should have the same length as the return value of `get_feature_map_channels`.
+              E.g., for a ResNet-50, it should return a list [1st_feat_map, 2nd_feat_map, 3rd_feat_map, 4th_feat_map].
+              For an input size of 224, the shapes are [(B, 256, 56, 56), (B, 512, 28, 28), (B, 1024, 14, 14), (B, 2048, 7, 7)].
+        """
+        raise NotImplementedError
+
+
+@register_model
+def your_convnet_small(pretrained=False, **kwargs):
+    raise NotImplementedError
+    return YourConvNet(**kwargs)
+
+
+@torch.no_grad()
+def convnet_test():
+    from timm.models import create_model
+    cnn = create_model('your_convnet_small')
+    print('get_downsample_ratio:', cnn.get_downsample_ratio())
+    print('get_feature_map_channels:', cnn.get_feature_map_channels())
+    
+    downsample_ratio = cnn.get_downsample_ratio()
+    feature_map_channels = cnn.get_feature_map_channels()
+    
+    # check the forward function
+    B, C, H, W = 4, 3, 224, 224
+    inp = torch.rand(B, C, H, W)
+    feats = cnn(inp, hierarchical=True)
+    assert isinstance(feats, list)
+    assert len(feats) == len(feature_map_channels)
+    print([tuple(t.shape) for t in feats])
+    
+    # check the downsample ratio
+    feats = cnn(inp, hierarchical=True)
+    assert feats[-1].shape[-2] == H // downsample_ratio
+    assert feats[-1].shape[-1] == W // downsample_ratio
+    
+    # check the channel number
+    for feat, ch in zip(feats, feature_map_channels):
+        assert feat.ndim == 4
+        assert feat.shape[1] == ch
+
+
+if __name__ == '__main__':
+    convnet_test()

@@ -3,15 +3,26 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+from typing import List
+import math
 import torch
 import torch.nn.functional as F
 from timm.models.resnet import ResNet
-def forward(self, x, hierarchy=0):  # hierarchy: 0 or 1 or 2 or 3 or 4
+# hack: inject the `get_downsample_ratio` function into `timm.models.resnet.ResNet`
+def get_downsample_ratio(self: ResNet) -> int:
+    return 32
+# hack: inject the `get_feature_map_channels` function into `timm.models.resnet.ResNet`
+def get_feature_map_channels(self: ResNet) -> List[int]:
+    # `self.feature_info` is maintained by `timm`
+    return [info['num_chs'] for info in self.feature_info[1:]]
+# hack: override the forward function of `timm.models.resnet.ResNet`
+def forward(self, x, hierarchical=False):
     """ this forward function is a modified version of `timm.models.resnet.ResNet.forward`
     >>> ResNet.forward
     """
@@ -20,17 +31,12 @@ def forward(self, x, hierarchy=0):  # hierarchy: 0 or 1 or 2 or 3 or 4
     x = self.act1(x)
     x = self.maxpool(x)
-    ls = []
-    x = self.layer1(x)
-    ls.append(x if hierarchy >= 4 else None)
-    x = self.layer2(x)
-    ls.append(x if hierarchy >= 3 else None)
-    x = self.layer3(x)
-    ls.append(x if hierarchy >= 2 else None)
-    x = self.layer4(x)
-    ls.append(x if hierarchy >= 1 else None)
-    if hierarchy:
+    if hierarchical:
+        ls = []
+        x = self.layer1(x); ls.append(x)
+        x = self.layer2(x); ls.append(x)
+        x = self.layer3(x); ls.append(x)
+        x = self.layer4(x); ls.append(x)
         return ls
     else:
         x = self.global_pool(x)
@@ -40,19 +46,39 @@ def forward(self, x, hierarchy=0):  # hierarchy: 0 or 1 or 2 or 3 or 4
     return x
+ResNet.get_downsample_ratio = get_downsample_ratio
+ResNet.get_feature_map_channels = get_feature_map_channels
 ResNet.forward = forward
-if __name__ == '__main__':
+@torch.no_grad()
+def convnet_test():
     from timm.models import create_model
-    r50 = create_model('resnet50')
-    def prt(lst):
-        print([tuple(t.shape) if t is not None else '(None)' for t in lst])
-    with torch.no_grad():
-        inp = torch.rand(2, 3, 224, 224)
-        prt(r50(inp))
-        prt(r50(inp, hierarchy=1))
-        prt(r50(inp, hierarchy=2))
-        prt(r50(inp, hierarchy=3))
-        prt(r50(inp, hierarchy=4))
+    cnn = create_model('resnet50')
+    print('get_downsample_ratio:', cnn.get_downsample_ratio())
+    print('get_feature_map_channels:', cnn.get_feature_map_channels())
+    downsample_ratio = cnn.get_downsample_ratio()
+    feature_map_channels = cnn.get_feature_map_channels()
+    # check the forward function
+    B, C, H, W = 4, 3, 224, 224
+    inp = torch.rand(B, C, H, W)
+    feats = cnn(inp, hierarchical=True)
+    assert isinstance(feats, list)
+    assert len(feats) == len(feature_map_channels)
+    print([tuple(t.shape) for t in feats])
+    # check the downsample ratio
+    feats = cnn(inp, hierarchical=True)
+    assert feats[-1].shape[-2] == H // downsample_ratio
+    assert feats[-1].shape[-1] == W // downsample_ratio
+    # check the channel number
+    for feat, ch in zip(feats, feature_map_channels):
+        assert feat.ndim == 4
+        assert feat.shape[1] == ch
+if __name__ == '__main__':
+    convnet_test()

@@ -19,7 +19,7 @@ def worker_init_fn(worker_id):
 class DistInfiniteBatchSampler(Sampler):
-    def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=0, filling=False, shuffle=True):
+    def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=1, filling=False, shuffle=True):
         assert glb_batch_size % world_size == 0
         self.world_size, self.rank = world_size, rank
         self.dataset_len = dataset_len

@@ -20,7 +20,7 @@ from decoder import LightDecoder
 class SparK(nn.Module):
     def __init__(
         self, sparse_encoder: encoder.SparseEncoder, dense_decoder: LightDecoder,
-        mask_ratio=0.6, densify_norm='bn', sbn=False, hierarchy=4,
+        mask_ratio=0.6, densify_norm='bn', sbn=False,
     ):
         super().__init__()
         input_size, downsample_raito = sparse_encoder.input_size, sparse_encoder.downsample_raito
@@ -33,15 +33,23 @@ class SparK(nn.Module):
         self.dense_decoder = dense_decoder
         self.sbn = sbn
-        self.hierarchy = hierarchy
+        self.hierarchy = len(sparse_encoder.enc_feat_map_chs)
         self.densify_norm_str = densify_norm.lower()
         self.densify_norms = nn.ModuleList()
         self.densify_projs = nn.ModuleList()
         self.mask_tokens = nn.ParameterList()
         # build the `densify` layers
-        e_width, d_width = self.sparse_encoder.fea_dim, self.dense_decoder.width
-        for i in range(self.hierarchy):
+        e_widths, d_width = self.sparse_encoder.enc_feat_map_chs, self.dense_decoder.width
+        e_widths: List[int]
+        for i in range(self.hierarchy):  # from the smallest feat map to the largest; i=0: the last feat map; i=1: the second last feat map ...
+            e_width = e_widths.pop()
+            # create mask token
+            p = nn.Parameter(torch.zeros(1, e_width, 1, 1))
+            trunc_normal_(p, mean=0, std=.02, a=-.02, b=.02)
+            self.mask_tokens.append(p)
+            # create densify norm
             if self.densify_norm_str == 'bn':
                 densify_norm = (encoder.SparseSyncBatchNorm2d if self.sbn else encoder.SparseBatchNorm2d)(e_width)
             elif self.densify_norm_str == 'ln':
@@ -50,6 +58,7 @@ class SparK(nn.Module):
                 densify_norm = nn.Identity()
             self.densify_norms.append(densify_norm)
+            # create densify proj
             if i == 0 and e_width == d_width:
                 densify_proj = nn.Identity()  # todo: NOTE THAT CONVNEXT-S WOULD USE THIS, because it has a width of 768 that equals to the decoder's width 768
                 print(f'[SparK.__init__, densify {i+1}/{self.hierarchy}]: use nn.Identity() as densify_proj')
@@ -59,10 +68,7 @@ class SparK(nn.Module):
                 print(f'[SparK.__init__, densify {i+1}/{self.hierarchy}]: densify_proj(ksz={kernel_size}, #para={sum(x.numel() for x in densify_proj.parameters()) / 1e6:.2f}M)')
             self.densify_projs.append(densify_proj)
-            p = nn.Parameter(torch.zeros(1, e_width, 1, 1))
-            trunc_normal_(p, mean=0, std=.02, a=-.02, b=.02)
-            self.mask_tokens.append(p)
-            e_width //= 2
+            # todo: the decoder's width follows a simple halving rule; you can change it to any other rule
             d_width //= 2
         print(f'[SparK.__init__] dims of mask_tokens={tuple(p.numel() for p in self.mask_tokens)}')
@@ -89,7 +95,7 @@ class SparK(nn.Module):
         masked_bchw = inp_bchw * active_b1hw
         # step2. Encode: get hierarchical encoded sparse features (a list containing 4 feature maps at 4 scales)
-        fea_bcffs: List[torch.Tensor] = self.sparse_encoder(masked_bchw, hierarchy=self.hierarchy)
+        fea_bcffs: List[torch.Tensor] = self.sparse_encoder(masked_bchw)
         fea_bcffs.reverse()  # after reversion: from the smallest feature map to the largest
         # step3. Densify: get hierarchical dense features for decoding

@@ -15,19 +15,16 @@ import dist
 class Args(Tap):
     # environment
-    exp_name: str
-    exp_dir: str
-    data_path: str
+    exp_name: str = 'your_exp_name'
+    exp_dir: str = 'your_exp_dir'   # will be created if not exists
+    data_path: str = 'imagenet_data_path'
     resume_from: str = ''   # resume from some checkpoint.pth
-    seed: int = 1
     # SparK hyperparameters
-    mask: float = 0.6
-    hierarchy: int = 4
+    mask: float = 0.6   # mask ratio, should be in (0, 1)
     # encoder hyperparameters
-    model: str = 'res50'
-    model_alias: str = 'res50'
+    model: str = 'resnet50'
     input_size: int = 224
     sbn: bool = True
@@ -70,7 +67,7 @@ class Args(Tap):
     @property
     def is_resnet(self):
-        return 'resnet' in self.model or 'res' in self.model_alias
+        return 'resnet' in self.model
     def log_epoch(self):
         if not dist.is_local_master():
@@ -96,7 +93,6 @@ class Args(Tap):
 def init_dist_and_get_args():
     from utils import misc
-    from models import model_alias_to_fullname, model_fullname_to_alias
     # initialize
     args = Args(explicit_bool=True).parse_args()
@@ -116,10 +112,6 @@ def init_dist_and_get_args():
     misc.init_distributed_environ(exp_dir=args.exp_dir)
     # update args
-    if args.model in model_alias_to_fullname.keys():
-        args.model = model_alias_to_fullname[args.model]
-    args.model_alias = model_fullname_to_alias[args.model]
     args.first_logging = True
     args.device = dist.get_device()
     args.batch_size_per_gpu = args.bs // dist.get_world_size()
@@ -137,7 +129,4 @@ def init_dist_and_get_args():
     args.lr = args.base_lr * args.glb_batch_size / 256
     args.wde = args.wde or args.wd
-    if args.hierarchy < 1:
-        args.hierarchy = 1
     return args
