# SparK: the first successful BERT/MAE-style pretraining on *any* convolutional networks [![Reddit](https://img.shields.io/badge/Reddit-🔥%20120k%20views-b31b1b.svg?style=social&logo=reddit)](https://www.reddit.com/r/MachineLearning/comments/10ix0l1/r_iclr2023_spotlight_the_first_bertstyle/) [![Twitter](https://img.shields.io/badge/Twitter-🔥%2020k%2B120k%20views-b31b1b.svg?style=social&logo=twitter)](https://twitter.com/keyutian/status/1616606179144380422)
This is the official implementation of ICLR paper [Designing BERT for Convolutional Networks: ***Spar***se and Hierarchical Mas***k***ed Modeling](https://arxiv.org/abs/2301.03580), which can pretrain **any CNN** (e.g., ResNet) in a **BERT-style self-supervised** manner.
We've tried our best to keep the codebase clean, short, easy to read, state-of-the-art, and reliant on only minimal dependencies.
https://user-images.githubusercontent.com/39692511/226858919-dd4ccf7e-a5ba-4a33-ab21-4785b8a7833c.mp4
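Dense convolution layers gradually erase the mask pattern: a zeroed-out (masked) position picks up nonzero values from its neighbors after just one layer. A toy 1-D convolution in plain Python (the `conv1d` helper is a hypothetical illustration, not code from this repo) shows the effect:

```python
# Convolving a masked (zeroed) input with a dense kernel leaks neighboring
# values into the masked positions, so the mask "vanishes" after one layer.
def conv1d(x, kernel):
    k = len(kernel) // 2
    padded = [0.0] * k + x + [0.0] * k  # zero padding
    return [sum(kernel[j] * padded[i + j] for j in range(len(kernel)))
            for i in range(len(x))]

x = [1.0, 1.0, 0.0, 0.0, 1.0, 1.0]   # positions 2 and 3 are masked out
y = conv1d(x, [1.0, 1.0, 1.0])
print(y)  # masked positions are no longer zero: [2.0, 2.0, 1.0, 1.0, 2.0, 2.0]
```

Sparse convolution skips the masked positions entirely, which is why SparK uses it for encoding.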
We also provide [pretrain/viz_spconv.ipynb](pretrain/viz_spconv.ipynb), which visualizes the "mask pattern vanishing" issue of dense conv layers.

## What's new here?

### 🔥 Pretrained CNN beats pretrained Swin-Transformer:
### 🔥 After SparK pretraining, smaller models can beat un-pretrained larger models:
### 🔥 All models can benefit, showing a scaling behavior:
### 🔥 Generative self-supervised pretraining surpasses contrastive learning:
#### See our [paper](https://arxiv.org/pdf/2301.03580.pdf) for more analysis, discussions, and evaluations.
## Todo list
- [x] Pretraining code
- [x] Pretraining tutorial for customized CNN models ([Tutorial for pretraining your own CNN model](https://github.com/keyu-tian/SparK/tree/main/pretrain/#tutorial-for-pretraining-your-own-cnn-model))
- [x] Pretraining tutorial for customized datasets ([Tutorial for pretraining your own dataset](https://github.com/keyu-tian/SparK/tree/main/pretrain/#tutorial-for-pretraining-your-own-dataset))
- [x] Pretraining Colab visualization playground ([reconstruction](/pretrain/viz_reconstruction.ipynb), [sparse conv](pretrain/viz_spconv.ipynb))
- [x] Finetuning code
- [ ] Weights & visualization playground in `huggingface`
- [ ] Weights in `timm`
## Pretrained weights (with SparK's UNet-style decoder; can be used to reconstruct images)
| arch. | resolution | acc@1 | #params | FLOPs | weights (self-supervised, with SparK's decoder) |
|:----------:|:----------:|:-----:|:-------:|:------:|:------------------------------------------------------------------------------------------------------------------------------------------|
| ResNet50 | 224 | 80.6 | 26M | 4.1G | [res50_withdecoder_1kpretrained_spark_style.pth](https://drive.google.com/file/d/1STt3w3e5q9eCPZa8VzcJj1zG6p3jLeSF/view?usp=share_link) |
| ResNet101 | 224 | 82.2 | 45M | 7.9G | [res101_withdecoder_1kpretrained_spark_style.pth](https://drive.google.com/file/d/1GjN48LKtlop2YQre6---7ViCWO-3C0yr/view?usp=share_link) |
| ResNet152 | 224 | 82.7 | 60M | 11.6G | [res152_withdecoder_1kpretrained_spark_style.pth](https://drive.google.com/file/d/1U3Cd94j4ZHfYR2dUjWmsEWfjP6Opx4oo/view?usp=share_link) |
| ResNet200 | 224 | 83.1 | 65M | 15.1G | [res200_withdecoder_1kpretrained_spark_style.pth](https://drive.google.com/file/d/13AFSqvIr0v-2hmb4DzVza45t_lhf2CnD/view?usp=share_link) |
| ConvNeXt-S | 224 | 84.1 | 50M | 8.7G | [cnxS224_withdecoder_1kpretrained_spark_style.pth](https://drive.google.com/file/d/1bKvrE4sNq1PfzhWlQJXEPrl2kHqHRZM-/view?usp=share_link) |
| ConvNeXt-L | 384 | 86.0 | 198M | 101.0G | [cnxL384_withdecoder_1kpretrained_spark_style.pth](https://drive.google.com/file/d/1ZI9Jgtb3fKWE_vDFEly29w-1FWZSNwa0/view?usp=share_link) |
## Installation & Running
We highly recommend using `torch==1.10.0`, `torchvision==0.11.1`, and `timm==0.5.4` for reproduction.
Check [INSTALL.md](INSTALL.md) to install all pip dependencies.
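The recommended versions can be pinned in one line (assuming a pip-based environment; [INSTALL.md](INSTALL.md) remains the authoritative dependency list):

```shell
pip install torch==1.10.0 torchvision==0.11.1 timm==0.5.4
```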
- **Loading pretrained model weights in 3 lines**
```python
# download our weights `resnet50_1kpretrained_timm_style.pth` first
import torch, timm
res50, state = timm.create_model('resnet50'), torch.load('resnet50_1kpretrained_timm_style.pth', map_location='cpu')
res50.load_state_dict(state.get('module', state), strict=False)  # the weights may be nested under state['module']
```
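The `state.get('module', state)` fallback handles checkpoints whose weights are nested under a `'module'` key (as with checkpoints saved from `torch.nn.parallel` wrappers). A minimal illustration with plain dicts (the `unwrap_checkpoint` helper is hypothetical, not part of this repo):

```python
# Some checkpoints store the state_dict under a 'module' key;
# plain checkpoints store it flat. Handle both layouts uniformly.
def unwrap_checkpoint(state):
    """Return the inner state_dict whether or not it is nested under 'module'."""
    return state.get('module', state)

flat = {'conv1.weight': [0.0, 1.0]}
nested = {'module': flat, 'epoch': 100}

assert unwrap_checkpoint(flat) is flat       # already flat: returned unchanged
assert unwrap_checkpoint(nested) is flat     # nested: inner dict is extracted
```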
- **Pretraining**
- any ResNet or ConvNeXt on ImageNet-1k: see [pretrain/](pretrain)
- **your own CNN model**: see [pretrain/](pretrain), especially [pretrain/models/custom.py](pretrain/models/custom.py)
- **Finetuning**
    - any ResNet or ConvNeXt on ImageNet-1k: see [downstream_imagenet/](downstream_imagenet)
- ResNets on COCO: see [downstream_d2/](downstream_d2)
- ConvNeXts on COCO: see [downstream_mmdet/](downstream_mmdet)
## Acknowledgement
We referred to these useful codebases:
- [BEiT](https://github.com/microsoft/unilm/tree/master/beit), [MAE](https://github.com/facebookresearch/mae), [ConvNeXt](https://github.com/facebookresearch/ConvNeXt)
- [timm](https://github.com/rwightman/pytorch-image-models), [MoCoV2](https://github.com/facebookresearch/moco), [Detectron2](https://github.com/facebookresearch/detectron2), [MMDetection](https://github.com/open-mmlab/mmdetection)
## License
This project is under the MIT license. See [LICENSE](LICENSE) for more details.
## Citation
If you find this project useful, please give us a star ⭐ or cite us in your work 📖:
```bibtex
@article{tian2023designing,
  author  = {Keyu Tian and Yi Jiang and Qishuai Diao and Chen Lin and Liwei Wang and Zehuan Yuan},
  title   = {Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling},
  journal = {arXiv:2301.03580},
  year    = {2023},
}
```