From 6ffe453fa5b14bac54d0628ee214be8d1a5d18b1 Mon Sep 17 00:00:00 2001 From: keyu-tian Date: Tue, 4 Apr 2023 17:13:19 +0800 Subject: [PATCH] [upd] 1. refactor a lot to simplify the pretraining codes; 2. add tutorial for customizing your own CNN model; 3. update some READMEs --- README.md | 14 +-- downstream_mmdet/README.md | 4 +- ...dow7_mstrain_480-800_adamw_3x_coco_in1k.py | 8 +- pretrain/README.md | 31 ++++--- pretrain/decoder.py | 4 +- pretrain/dist.py | 12 ++- pretrain/encoder.py | 10 +-- pretrain/main.py | 4 +- pretrain/models/__init__.py | 44 ++++----- pretrain/models/convnext.py | 37 ++++---- pretrain/models/custom.py | 89 +++++++++++++++++++ pretrain/models/resnet.py | 76 ++++++++++------ pretrain/sampler.py | 2 +- pretrain/spark.py | 24 +++-- pretrain/utils/arg_util.py | 23 ++--- 15 files changed, 250 insertions(+), 132 deletions(-) create mode 100644 pretrain/models/custom.py diff --git a/README.md b/README.md index 22d4a58..d345199 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ https://user-images.githubusercontent.com/39692511/226858919-dd4ccf7e-a5ba-4a33- ## 🔥 News -- On **Mar. 22nd (UTC+8 8pm; UTC+0 12am)** another livestream would be held at [极市平台-bilibili](https://live.bilibili.com/3344545). +- On **Mar. 22nd (UTC+8 8pm)** another livestream would be held at 极市平台-Bilibili! [[`📹Recorded Video`](https://www.bilibili.com/video/BV1Da4y1T7mr/)] - The share on [TechBeat (将门创投)](https://www.techbeat.net/talk-info?id=758) is scheduled on **Mar. 16th (UTC+8 8pm)** too! [[`📹Recorded Video`](https://www.techbeat.net/talk-info?id=758)] - We are honored to be invited by Synced ("机器之心机动组 视频号" on WeChat) to give a talk about SparK on **Feb. 27th (UTC+0 11am, UTC+8 7pm)**, welcome! [[`📹Recorded Video`](https://www.bilibili.com/video/BV1J54y1u7U3/)] - This work got accepted to ICLR 2023 as a Spotlight (notable-top-25%). @@ -44,7 +44,7 @@ https://user-images.githubusercontent.com/39692511/226858919-dd4ccf7e-a5ba-4a33- -## 🕹️ CoLab Visualization Demo +## 🕹️ Colab Visualization Demo Check [pretrain/viz_reconstruction.ipynb](pretrain/viz_reconstruction.ipynb) for visualizing the reconstruction of SparK pretrained models, like: @@ -94,9 +94,10 @@ We also provide [pretrain/viz_spconv.ipynb](pretrain/viz_spconv.ipynb) that show catalog - [x] Pretraining code +- [x] Pretraining toturial for custom CNN model ([pretrain/models/custom.py](pretrain/models/custom.py)) +- [x] Pretraining Colab visualization playground ([reconstruction](/pretrain/viz_reconstruction.ipynb), [sparse conv](pretrain/viz_spconv.ipynb)) - [x] Finetuning code -- [x] Colab visualization playground ([reconstruction](/pretrain/viz_reconstruction.ipynb), [sparse conv](pretrain/viz_spconv.ipynb)) -- [ ] Weights & visualization playground on `Huggingface` +- [ ] Weights & visualization playground in `huggingface` - [ ] Weights in `timm` @@ -128,7 +129,10 @@ We also provide [pretrain/viz_spconv.ipynb](pretrain/viz_spconv.ipynb) that show We highly recommended you to use `torch==1.10.0`, `torchvision==0.11.1`, and `timm==0.5.4` for reproduction. Check [INSTALL.md](INSTALL.md) to install all pip dependencies. -- **Pretraining** all models on ImageNet-1k:  see [pretrain/](pretrain) +- **Pretraining** + - all ResNets and ConvNeXts on ImageNet-1k:  see [pretrain/](pretrain) + - **your own CNN models**:  see [pretrain/](pretrain), especially [pretrain/models/custom.py](pretrain/models/custom.py) + - **Finetuning** - all models on ImageNet:  check [downstream_imagenet/](downstream_imagenet) for subsequent instructions. 
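
> **Editor's note (illustration only):** the "your own CNN models" entry above points to the `pretrain/models/custom.py` template added later in this patch. As a quick sketch of the interface that template expects, here is a minimal, hypothetical backbone — the name `tiny_cnn` and all layer widths/depths are made up for this example; only the three required member functions and the `@register_model` registration mirror the actual template.

```python
# Hypothetical custom backbone ("tiny_cnn") -- a sketch, not part of this patch.
# It follows the interface required by pretrain/models/custom.py:
#   get_downsample_ratio(), get_feature_map_channels(), forward(x, hierarchical).
from typing import List

import torch
import torch.nn as nn
from timm.models.registry import register_model


class TinyCNN(nn.Module):
    def __init__(self, dims=(64, 128, 256, 512), **kwargs):  # kwargs absorbs num_classes=0, global_pool='' passed by the pre-training code
        super().__init__()
        self.dims = list(dims)
        self.stem = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)  # downsample x2
        in_chs = [32] + self.dims[:-1]
        self.stages = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_chs[i], self.dims[i], kernel_size=3, stride=2, padding=1),  # each stage downsamples x2
                nn.BatchNorm2d(self.dims[i]),
                nn.ReLU(inplace=True),
            )
            for i in range(len(self.dims))
        ])

    def get_downsample_ratio(self) -> int:
        return 32                  # stem (x2) * 4 stages (x2 each); this also equals the mask patch size

    def get_feature_map_channels(self) -> List[int]:
        return self.dims           # one entry per feature map returned in hierarchical mode

    def forward(self, x: torch.Tensor, hierarchical: bool = False):
        x = self.stem(x)
        feats = []
        for stage in self.stages:
            x = stage(x)
            feats.append(x)
        if hierarchical:
            return feats           # shallow-to-deep feature maps, as SparseEncoder expects
        return x.mean(dim=(2, 3))  # pooled features (no classifier head is needed during pre-training)


@register_model
def tiny_cnn(pretrained=False, **kwargs):
    return TinyCNN(**kwargs)
```

Per the tutorial section this patch adds to `pretrain/README.md`, one would then list an entry such as `'tiny_cnn': dict()` in `pretrain_default_model_kwargs` (in `pretrain/models/__init__.py`) and pass `--model=tiny_cnn` to `main.sh`.
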
diff --git a/downstream_mmdet/README.md b/downstream_mmdet/README.md index 77d0dbf..85d7bb2 100644 --- a/downstream_mmdet/README.md +++ b/downstream_mmdet/README.md @@ -19,9 +19,9 @@ This `downstream_mmdet` is isolated from pre-training codes. One can treat this

-## Installation [MMDetection with commit 6a979e2](https://github.com/facebookresearch/detectron2/releases/tag/v0.6) before fine-tuning ConvNeXt on COCO +## Installation [MMDetection with commit 6a979e2](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d) before fine-tuning ConvNeXt on COCO -We refer to the codebases of [ConvNeXt](https://github.com/facebookresearch/ConvNeXt/tree/048efcea897d999aed302f2639b6270aedf8d4c8) and [Swin-Transformer](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d). +We refer to the codebases of [ConvNeXt](https://github.com/facebookresearch/ConvNeXt/tree/048efcea897d999aed302f2639b6270aedf8d4c8) and [Swin-Transformer-Object-Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d). Please refer to [README.md](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/6a979e2164e3fb0de0ca2546545013a4d71b2f7d/README.md) for installation and dataset preparation instructions. Note the COCO dataset folder should be at `downstream_mmdet/data/coco`. diff --git a/downstream_mmdet/configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py b/downstream_mmdet/configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py index 465fe4f..566866b 100644 --- a/downstream_mmdet/configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py +++ b/downstream_mmdet/configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py @@ -23,13 +23,13 @@ _base_ = [ model = dict( backbone=dict( in_chans=3, - depths=[3, 3, 27, 3], # [modified] according to tiny to base - dims=[128, 256, 512, 1024], # [modified] according to tiny to base - drop_path_rate=0.5, # [modified] according to tiny to base + depths=[3, 3, 27, 3], # [modified] according to tiny-to-base + dims=[128, 256, 512, 1024], # [modified] according to tiny-to-base + drop_path_rate=0.5, # [modified] according to tiny-to-base layer_scale_init_value=1.0, out_indices=[0, 1, 2, 3], ), - neck=dict(in_channels=[128, 256, 512, 1024])) # [modified] according to tiny to base + neck=dict(in_channels=[128, 256, 512, 1024])) # [modified] according to tiny-to-base img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) diff --git a/pretrain/README.md b/pretrain/README.md index 0b1217d..0fb8f94 100644 --- a/pretrain/README.md +++ b/pretrain/README.md @@ -1,17 +1,30 @@ ## Preparation for ImageNet-1k pre-training -See [INSTALL.md](https://github.com/keyu-tian/SparK/blob/main/INSTALL.md) to prepare `pip` dependencies and the ImageNet dataset. +See [/INSTALL.md](/INSTALL.md) to prepare `pip` dependencies and the ImageNet dataset. **Note: for network definitions, we directly use `timm.models.ResNet` and [official ConvNeXt](https://github.com/facebookresearch/ConvNeXt/blob/048efcea897d999aed302f2639b6270aedf8d4c8/models/convnext.py).** +## Tutorial for customizing your own CNN model + +See [/pretrain/models/custom.py](/pretrain/models/custom.py). The things needed to do is: + +- implementing member function `get_downsample_ratio` in [/pretrain/models/custom.py](/pretrain/models/custom.py). +- implementing member function `get_feature_map_channels` in [/pretrain/models/custom.py](/pretrain/models/custom.py). 
+- implementing member function `forward` in [/pretrain/models/custom.py](/pretrain/models/custom.py). +- define `your_convnet(...)` with `@register_model` in [/pretrain/models/custom.py](/pretrain/models/custom.py). +- add default kwargs of `your_convnet(...)` in [/pretrain/models/__init__.py](/pretrain/models/__init__.py). + +Then you can use `--model=your_convnet` in the pre-training script. + + ## Pre-training Any Model on ImageNet-1k (224x224) -For pre-training, run [main.sh](https://github.com/keyu-tian/SparK/blob/main/pretrain/main.sh) with bash. +For pre-training, run [/pretrain/main.sh](/pretrain/main.sh) with bash. It is **required** to specify the ImageNet data folder (`--data_path`), the model name (`--model`), and your experiment name (the first argument of `main.sh`) when running the script. We use the **same** pre-training configurations (lr, batch size, etc.) for all models (ResNets and ConvNeXts). -Their names and **default values** can be found in [utils/arg_util.py line24-47](https://github.com/keyu-tian/SparK/blob/main/pretrain/utils/arg_util.py#L24). +Their names and **default values** can be found in [/pretrain/utils/arg_util.py line24-47](/pretrain/utils/arg_util.py). These default configurations (like batch size 4096) would be used, unless you specify some like `--bs=512`. Here is an example command pre-training a ResNet50 on single machine with 8 GPUs: @@ -69,8 +82,8 @@ Add `--resume_from=path/to/still_pretraining.pth` to resume from a saved ## Regarding sparse convolution We do not use sparse convolutions in this pytorch implementation, due to their limited optimization on modern hardwares. -As can be found in [encoder.py](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py), we use masked dense convolution to simulate submanifold sparse convolution. -We also define some sparse pooling or normalization layers in [encoder.py](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py). +As can be found in [/pretrain/encoder.py](/pretrain/encoder.py), we use masked dense convolution to simulate submanifold sparse convolution. +We also define some sparse pooling or normalization layers in [/pretrain/encoder.py](/pretrain/encoder.py). All these "sparse" layers are implemented through pytorch built-in operators. @@ -80,10 +93,8 @@ In SparK, the mask patch size **equals to** the downsample ratio of the CNN mode Here is the reason: when we do mask, we: -1. first generate the binary mask for the **smallest** resolution feature map, i.e., generate the `_cur_active` or `active_b1ff` in [line86-87](https://github.com/keyu-tian/SparK/blob/main/pretrain/spark.py#L86), which is a `torch.BoolTensor` shaped as `[B, 1, fmap_size, fmap_size]`, and would be used to mask the smallest feature map. -3. then progressively upsample it (i.e., expand its 2nd and 3rd dimensions by calling `repeat_interleave(..., 2)` and `repeat_interleave(..., 3)` in [line16](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py#L16)), to mask those feature maps ([`x` in line21](https://github.com/keyu-tian/SparK/blob/main/pretrain/encoder.py#L21)) with larger resolutions . +1. first generate the binary mask for the **smallest** resolution feature map, i.e., generate the `_cur_active` or `active_b1ff` in [/pretrain/spark.py line86-87](/pretrain/spark.py), which is a `torch.BoolTensor` shaped as `[B, 1, fmap_size, fmap_size]`, and would be used to mask the smallest feature map. +3. 
then progressively upsample it (i.e., expand its 2nd and 3rd dimensions by calling `repeat_interleave(..., 2)` and `repeat_interleave(..., 3)` in [/pretrain/encoder.py line16](/pretrain/encoder.py)), to mask those feature maps ([`x` in line21](/pretrain/encoder.py)) with larger resolutions . So if you want a patch size of 16 or 8, you should actually define a new CNN model with a downsample ratio of 16 or 8. -Note that the `forward` function of this CNN should have an arg named `hierarchy`. You can look at https://github.com/keyu-tian/SparK/blob/main/pretrain/models/convnext.py#L78 to see what `hierarchy` means and how to handle it. - -After that, you can simply run `main.sh` with `--hierarchy=3` and see if it works. +See `Tutorial for customizing your own CNN model` above. diff --git a/pretrain/decoder.py b/pretrain/decoder.py index 61f89b5..0b31df1 100644 --- a/pretrain/decoder.py +++ b/pretrain/decoder.py @@ -32,12 +32,12 @@ class UNetBlock(nn.Module): class LightDecoder(nn.Module): - def __init__(self, up_sample_ratio, width=768, sbn=True): + def __init__(self, up_sample_ratio, width=768, sbn=True): # todo: the decoder's width follows a simple halfing rule; you can change it to any other rule super().__init__() self.width = width assert is_pow2n(up_sample_ratio) n = round(math.log2(up_sample_ratio)) - channels = [self.width // 2 ** i for i in range(n + 1)] + channels = [self.width // 2 ** i for i in range(n + 1)] # todo: the decoder's width follows a simple halfing rule; you can change it to any other rule bn2d = nn.SyncBatchNorm if sbn else nn.BatchNorm2d self.dec = nn.ModuleList([UNetBlock(cin, cout, bn2d) for (cin, cout) in zip(channels[:-1], channels[1:])]) self.proj = nn.Conv2d(channels[-1], 3, kernel_size=1, stride=1, bias=True) diff --git a/pretrain/dist.py b/pretrain/dist.py index 0c12d84..bae9379 100644 --- a/pretrain/dist.py +++ b/pretrain/dist.py @@ -8,12 +8,11 @@ import os from typing import List from typing import Union +import sys import torch import torch.distributed as tdist import torch.multiprocessing as mp -from torch.distributed import barrier as __barrier -barrier = __barrier __rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu' __initialized = False @@ -23,6 +22,10 @@ def initialized(): def initialize(backend='nccl'): + if not torch.cuda.is_available(): + print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr) + return + # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29 if mp.get_start_method(allow_none=True) is None: mp.set_start_method('spawn') @@ -64,6 +67,11 @@ def is_local_master(): return __local_rank == 0 +def barrier(): + if __initialized: + tdist.barrier() + + def parallelize(net, syncbn=False): if syncbn: net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) diff --git a/pretrain/encoder.py b/pretrain/encoder.py index 46dcbe9..7342e31 100644 --- a/pretrain/encoder.py +++ b/pretrain/encoder.py @@ -156,10 +156,10 @@ class SparseConvNeXtBlock(nn.Module): class SparseEncoder(nn.Module): - def __init__(self, conv_model, input_size, downsample_raito, encoder_fea_dim, sbn=False, verbose=False): + def __init__(self, cnn, input_size, sbn=False, verbose=False): super(SparseEncoder, self).__init__() - self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=conv_model, verbose=verbose, sbn=sbn) - self.input_size, self.downsample_raito, self.fea_dim = input_size, downsample_raito, encoder_fea_dim + self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=cnn, verbose=verbose, sbn=sbn) + self.input_size, 
self.downsample_raito, self.enc_feat_map_chs = input_size, cnn.get_downsample_ratio(), cnn.get_feature_map_channels() @staticmethod def dense_model_to_sparse(m: nn.Module, verbose=False, sbn=False): @@ -204,5 +204,5 @@ class SparseEncoder(nn.Module): del m return oup - def forward(self, x, hierarchy): - return self.sp_cnn(x, hierarchy=hierarchy) + def forward(self, x): + return self.sp_cnn(x, hierarchical=True) diff --git a/pretrain/main.py b/pretrain/main.py index 82406e0..6fbcbdc 100644 --- a/pretrain/main.py +++ b/pretrain/main.py @@ -37,7 +37,7 @@ def main_pt(): data_loader_train = DataLoader( dataset=dataset_train, num_workers=args.dataloader_workers, pin_memory=True, batch_sampler=DistInfiniteBatchSampler( - dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, seed=args.seed, + dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(), ), worker_init_fn=worker_init_fn ) @@ -49,7 +49,7 @@ def main_pt(): dec = LightDecoder(enc.downsample_raito, sbn=args.sbn) model_without_ddp = SparK( sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask, - densify_norm=args.densify_norm, sbn=args.sbn, hierarchy=args.hierarchy, + densify_norm=args.densify_norm, sbn=args.sbn, ).to(args.device) print(f'[PT model] model = {model_without_ddp}\n') model: DistributedDataParallel = DistributedDataParallel(model_without_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False) diff --git a/pretrain/models/__init__.py b/pretrain/models/__init__.py index 6785a17..68aad6f 100644 --- a/pretrain/models/__init__.py +++ b/pretrain/models/__init__.py @@ -30,38 +30,30 @@ for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath): clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})' -model_alias_to_fullname = { - 'res50': 'resnet50', - 'res101': 'resnet101', - 'res152': 'resnet152', - 'res200': 'resnet200', - 'cnxS': 'convnext_small', - 'cnxB': 'convnext_base', - 'cnxL': 'convnext_large', +pretrain_default_model_kwargs = { + 'your_convnet': dict(), + 'resnet50': dict(drop_path_rate=0.05), + 'resnet101': dict(drop_path_rate=0.08), + 'resnet152': dict(drop_path_rate=0.10), + 'resnet200': dict(drop_path_rate=0.15), + 'convnext_small': dict(sparse=True, drop_path_rate=0.2), + 'convnext_base': dict(sparse=True, drop_path_rate=0.3), + 'convnext_large': dict(sparse=True, drop_path_rate=0.4), } -model_fullname_to_alias = {v: k for k, v in model_alias_to_fullname.items()} - - -pre_train_d = { # default drop_path_rate, num of para, FLOPs, downsample_ratio, num of channel - 'resnet50': [dict(drop_path_rate=0.05), 25.6, 4.1, 32, 2048], - 'resnet101': [dict(drop_path_rate=0.08), 44.5, 7.9, 32, 2048], - 'resnet152': [dict(drop_path_rate=0.10), 60.2, 11.6, 32, 2048], - 'resnet200': [dict(drop_path_rate=0.15), 64.7, 15.1, 32, 2048], - 'convnext_small': [dict(sparse=True, drop_path_rate=0.2), 50.0, 8.7, 32, 768], - 'convnext_base': [dict(sparse=True, drop_path_rate=0.3), 89.0, 15.4, 32, 1024], - 'convnext_large': [dict(sparse=True, drop_path_rate=0.4), 198.0, 34.4, 32, 1536], -} -for v in pre_train_d.values(): - v[0]['pretrained'] = False - v[0]['num_classes'] = 0 - v[0]['global_pool'] = '' +for kw in pretrain_default_model_kwargs.values(): + kw['pretrained'] = False + kw['num_classes'] = 0 + kw['global_pool'] = '' def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False): from encoder import SparseEncoder - 
kwargs, params, flops, downsample_raito, fea_dim = pre_train_d[name] + kwargs = pretrain_default_model_kwargs[name] if drop_path_rate != 0: kwargs['drop_path_rate'] = drop_path_rate print(f'[build_sparse_encoder] model kwargs={kwargs}') - return SparseEncoder(create_model(name, **kwargs), input_size=input_size, downsample_raito=downsample_raito, encoder_fea_dim=fea_dim, sbn=sbn, verbose=verbose) + cnn = create_model(name, **kwargs) + + return SparseEncoder(cnn, input_size=input_size, sbn=sbn, verbose=verbose) + diff --git a/pretrain/models/convnext.py b/pretrain/models/convnext.py index a205764..6b0169e 100644 --- a/pretrain/models/convnext.py +++ b/pretrain/models/convnext.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. # # This file is basically a copy of: https://github.com/facebookresearch/ConvNeXt/blob/06f7b05f922e21914916406141f50f82b4a15852/models/convnext.py +from typing import List import torch import torch.nn as nn @@ -34,7 +35,7 @@ class ConvNeXt(nn.Module): sparse=True, ): super().__init__() - + self.dims: List[int] = dims self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers stem = nn.Sequential( nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), @@ -75,13 +76,19 @@ class ConvNeXt(nn.Module): trunc_normal_(m.weight, std=.02) nn.init.constant_(m.bias, 0) - def forward(self, x, hierarchy=0): - ls = [] - for i in range(4): - x = self.downsample_layers[i](x) - x = self.stages[i](x) - ls.append(x if hierarchy >= 4-i else None) - if hierarchy: + def get_downsample_ratio(self) -> int: + return 32 + + def get_feature_map_channels(self) -> List[int]: + return self.dims + + def forward(self, x, hierarchical=False): + if hierarchical: + ls = [] + for i in range(4): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + ls.append(x) return ls else: return self.fc(self.norm(x.mean([-2, -1]))) # (B, C, H, W) =mean=> (B, C) =norm&fc=> (B, NumCls) @@ -116,17 +123,3 @@ def convnext_large(pretrained=False, in_22k=False, **kwargs): model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) return model - -if __name__ == '__main__': - from timm.models import create_model - cnx = create_model('convnext_small', sparse=False) - - def prt(lst): - print([tuple(t.shape) if t is not None else '(None)' for t in lst]) - with torch.no_grad(): - inp = torch.rand(2, 3, 224, 224) - prt(cnx(inp)) - prt(cnx(inp, hierarchy=1)) - prt(cnx(inp, hierarchy=2)) - prt(cnx(inp, hierarchy=3)) - prt(cnx(inp, hierarchy=4)) diff --git a/pretrain/models/custom.py b/pretrain/models/custom.py new file mode 100644 index 0000000..3ebbef4 --- /dev/null +++ b/pretrain/models/custom.py @@ -0,0 +1,89 @@ +# Copyright (c) ByteDance, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from typing import List +from timm.models.registry import register_model + + +class YourConvNet(nn.Module): + """ + This is a template for your custom ConvNet. + It is required to implement the following three functions: `get_downsample_ratio`, `get_feature_map_channels`, `forward`. + You can refer to the implementations in `pretrain\models\resnet.py` for an example. + """ + + def get_downsample_ratio(self) -> int: + """ + This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`). + + :return: the TOTAL downsample ratio of the ConvNet. 
+ E.g., for a ResNet-50, this should return 32. + """ + raise NotImplementedError + + def get_feature_map_channels(self) -> List[int]: + """ + This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`). + + :return: a list of the number of channels of each feature map. + E.g., for a ResNet-50, this should return [256, 512, 1024, 2048]. + """ + raise NotImplementedError + + def forward(self, inp_bchw: torch.Tensor, hierarchical=False): + """ + The forward with `hierarchical=True` would ONLY be used in `SparseEncoder.forward` (see `pretrain/encoder.py`). + + :param inp_bchw: input image tensor, shape: (batch_size, channels, height, width). + :param hierarchical: return the logits (not hierarchical), or the feature maps (hierarchical). + :return: + - hierarchical == False: return the logits of the classification task, shape: (batch_size, num_classes). + - hierarchical == True: return a list of all feature maps, which should have the same length as the return value of `get_feature_map_channels`. + E.g., for a ResNet-50, it should return a list [1st_feat_map, 2nd_feat_map, 3rd_feat_map, 4th_feat_map]. + for an input size of 224, the shapes are [(B, 256, 56, 56), (B, 512, 28, 28), (B, 1024, 14, 14), (B, 2048, 7, 7)] + """ + raise NotImplementedError + + +@register_model +def your_convnet_small(pretrained=False, **kwargs): + raise NotImplementedError + return YourConvNet(**kwargs) + + +@torch.no_grad() +def convnet_test(): + from timm.models import create_model + cnn = create_model('your_convnet_small') + print('get_downsample_ratio:', cnn.get_downsample_ratio()) + print('get_feature_map_channels:', cnn.get_feature_map_channels()) + + downsample_ratio = cnn.get_downsample_ratio() + feature_map_channels = cnn.get_feature_map_channels() + + # check the forward function + B, C, H, W = 4, 3, 224, 224 + inp = torch.rand(B, C, H, W) + feats = cnn(inp, hierarchical=True) + assert isinstance(feats, list) + assert len(feats) == len(feature_map_channels) + print([tuple(t.shape) for t in feats]) + + # check the downsample ratio + feats = cnn(inp, hierarchical=True) + assert feats[-1].shape[-2] == H // downsample_ratio + assert feats[-1].shape[-1] == W // downsample_ratio + + # check the channel number + for feat, ch in zip(feats, feature_map_channels): + assert feat.ndim == 4 + assert feat.shape[1] == ch + + +if __name__ == '__main__': + convnet_test() diff --git a/pretrain/models/resnet.py b/pretrain/models/resnet.py index dcc1ce3..62f76eb 100644 --- a/pretrain/models/resnet.py +++ b/pretrain/models/resnet.py @@ -3,15 +3,26 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
- -import math +from typing import List import torch import torch.nn.functional as F from timm.models.resnet import ResNet -def forward(self, x, hierarchy=0): # hierarchy: 0 or 1 or 2 or 3 or 4 +# hack: inject the `get_downsample_ratio` function into `timm.models.resnet.ResNet` +def get_downsample_ratio(self: ResNet) -> int: + return 32 + + +# hack: inject the `get_feature_map_channels` function into `timm.models.resnet.ResNet` +def get_feature_map_channels(self: ResNet) -> List[int]: + # `self.feature_info` is maintained by `timm` + return [info['num_chs'] for info in self.feature_info[1:]] + + +# hack: override the forward function of `timm.models.resnet.ResNet` +def forward(self, x, hierarchical=False): """ this forward function is a modified version of `timm.models.resnet.ResNet.forward` >>> ResNet.forward """ @@ -20,17 +31,12 @@ def forward(self, x, hierarchy=0): # hierarchy: 0 or 1 or 2 or 3 or 4 x = self.act1(x) x = self.maxpool(x) - ls = [] - x = self.layer1(x) - ls.append(x if hierarchy >= 4 else None) - x = self.layer2(x) - ls.append(x if hierarchy >= 3 else None) - x = self.layer3(x) - ls.append(x if hierarchy >= 2 else None) - x = self.layer4(x) - ls.append(x if hierarchy >= 1 else None) - - if hierarchy: + if hierarchical: + ls = [] + x = self.layer1(x); ls.append(x) + x = self.layer2(x); ls.append(x) + x = self.layer3(x); ls.append(x) + x = self.layer4(x); ls.append(x) return ls else: x = self.global_pool(x) @@ -40,19 +46,39 @@ def forward(self, x, hierarchy=0): # hierarchy: 0 or 1 or 2 or 3 or 4 return x +ResNet.get_downsample_ratio = get_downsample_ratio +ResNet.get_feature_map_channels = get_feature_map_channels ResNet.forward = forward -if __name__ == '__main__': +@torch.no_grad() +def convnet_test(): from timm.models import create_model - r50 = create_model('resnet50') + cnn = create_model('resnet50') + print('get_downsample_ratio:', cnn.get_downsample_ratio()) + print('get_feature_map_channels:', cnn.get_feature_map_channels()) - def prt(lst): - print([tuple(t.shape) if t is not None else '(None)' for t in lst]) - with torch.no_grad(): - inp = torch.rand(2, 3, 224, 224) - prt(r50(inp)) - prt(r50(inp, hierarchy=1)) - prt(r50(inp, hierarchy=2)) - prt(r50(inp, hierarchy=3)) - prt(r50(inp, hierarchy=4)) + downsample_ratio = cnn.get_downsample_ratio() + feature_map_channels = cnn.get_feature_map_channels() + + # check the forward function + B, C, H, W = 4, 3, 224, 224 + inp = torch.rand(B, C, H, W) + feats = cnn(inp, hierarchical=True) + assert isinstance(feats, list) + assert len(feats) == len(feature_map_channels) + print([tuple(t.shape) for t in feats]) + + # check the downsample ratio + feats = cnn(inp, hierarchical=True) + assert feats[-1].shape[-2] == H // downsample_ratio + assert feats[-1].shape[-1] == W // downsample_ratio + + # check the channel number + for feat, ch in zip(feats, feature_map_channels): + assert feat.ndim == 4 + assert feat.shape[1] == ch + + +if __name__ == '__main__': + convnet_test() diff --git a/pretrain/sampler.py b/pretrain/sampler.py index ce8a80a..3140a59 100644 --- a/pretrain/sampler.py +++ b/pretrain/sampler.py @@ -19,7 +19,7 @@ def worker_init_fn(worker_id): class DistInfiniteBatchSampler(Sampler): - def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=0, filling=False, shuffle=True): + def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=1, filling=False, shuffle=True): assert glb_batch_size % world_size == 0 self.world_size, self.rank = world_size, rank self.dataset_len = dataset_len diff --git 
a/pretrain/spark.py b/pretrain/spark.py index 9308fd6..c5cfa64 100644 --- a/pretrain/spark.py +++ b/pretrain/spark.py @@ -20,7 +20,7 @@ from decoder import LightDecoder class SparK(nn.Module): def __init__( self, sparse_encoder: encoder.SparseEncoder, dense_decoder: LightDecoder, - mask_ratio=0.6, densify_norm='bn', sbn=False, hierarchy=4, + mask_ratio=0.6, densify_norm='bn', sbn=False, ): super().__init__() input_size, downsample_raito = sparse_encoder.input_size, sparse_encoder.downsample_raito @@ -33,15 +33,23 @@ class SparK(nn.Module): self.dense_decoder = dense_decoder self.sbn = sbn - self.hierarchy = hierarchy + self.hierarchy = len(sparse_encoder.enc_feat_map_chs) self.densify_norm_str = densify_norm.lower() self.densify_norms = nn.ModuleList() self.densify_projs = nn.ModuleList() self.mask_tokens = nn.ParameterList() # build the `densify` layers - e_width, d_width = self.sparse_encoder.fea_dim, self.dense_decoder.width - for i in range(self.hierarchy): + e_widths, d_width = self.sparse_encoder.enc_feat_map_chs, self.dense_decoder.width + e_widths: List[int] + for i in range(self.hierarchy): # from the smallest feat map to the largest; i=0: the last feat map; i=1: the second last feat map ... + e_width = e_widths.pop() + # create mask token + p = nn.Parameter(torch.zeros(1, e_width, 1, 1)) + trunc_normal_(p, mean=0, std=.02, a=-.02, b=.02) + self.mask_tokens.append(p) + + # create densify norm if self.densify_norm_str == 'bn': densify_norm = (encoder.SparseSyncBatchNorm2d if self.sbn else encoder.SparseBatchNorm2d)(e_width) elif self.densify_norm_str == 'ln': @@ -50,6 +58,7 @@ class SparK(nn.Module): densify_norm = nn.Identity() self.densify_norms.append(densify_norm) + # create densify proj if i == 0 and e_width == d_width: densify_proj = nn.Identity() # todo: NOTE THAT CONVNEXT-S WOULD USE THIS, because it has a width of 768 that equals to the decoder's width 768 print(f'[SparK.__init__, densify {i+1}/{self.hierarchy}]: use nn.Identity() as densify_proj') @@ -59,10 +68,7 @@ class SparK(nn.Module): print(f'[SparK.__init__, densify {i+1}/{self.hierarchy}]: densify_proj(ksz={kernel_size}, #para={sum(x.numel() for x in densify_proj.parameters()) / 1e6:.2f}M)') self.densify_projs.append(densify_proj) - p = nn.Parameter(torch.zeros(1, e_width, 1, 1)) - trunc_normal_(p, mean=0, std=.02, a=-.02, b=.02) - self.mask_tokens.append(p) - e_width //= 2 + # todo: the decoder's width follows a simple halfing rule; you can change it to any other rule d_width //= 2 print(f'[SparK.__init__] dims of mask_tokens={tuple(p.numel() for p in self.mask_tokens)}') @@ -89,7 +95,7 @@ class SparK(nn.Module): masked_bchw = inp_bchw * active_b1hw # step2. Encode: get hierarchical encoded sparse features (a list containing 4 feature maps at 4 scales) - fea_bcffs: List[torch.Tensor] = self.sparse_encoder(masked_bchw, hierarchy=self.hierarchy) + fea_bcffs: List[torch.Tensor] = self.sparse_encoder(masked_bchw) fea_bcffs.reverse() # after reversion: from the smallest feature map to the largest # step3. 
Densify: get hierarchical dense features for decoding diff --git a/pretrain/utils/arg_util.py b/pretrain/utils/arg_util.py index 1a8cd63..46a9865 100644 --- a/pretrain/utils/arg_util.py +++ b/pretrain/utils/arg_util.py @@ -15,19 +15,16 @@ import dist class Args(Tap): # environment - exp_name: str - exp_dir: str - data_path: str + exp_name: str = 'your_exp_name' + exp_dir: str = 'your_exp_dir' # will be created if not exists + data_path: str = 'imagenet_data_path' resume_from: str = '' # resume from some checkpoint.pth - seed: int = 1 # SparK hyperparameters - mask: float = 0.6 - hierarchy: int = 4 + mask: float = 0.6 # mask ratio, should be in (0, 1) # encoder hyperparameters - model: str = 'res50' - model_alias: str = 'res50' + model: str = 'resnet50' input_size: int = 224 sbn: bool = True @@ -70,7 +67,7 @@ class Args(Tap): @property def is_resnet(self): - return 'resnet' in self.model or 'res' in self.model_alias + return 'resnet' in self.model def log_epoch(self): if not dist.is_local_master(): @@ -96,7 +93,6 @@ class Args(Tap): def init_dist_and_get_args(): from utils import misc - from models import model_alias_to_fullname, model_fullname_to_alias # initialize args = Args(explicit_bool=True).parse_args() @@ -116,10 +112,6 @@ def init_dist_and_get_args(): misc.init_distributed_environ(exp_dir=args.exp_dir) # update args - if args.model in model_alias_to_fullname.keys(): - args.model = model_alias_to_fullname[args.model] - args.model_alias = model_fullname_to_alias[args.model] - args.first_logging = True args.device = dist.get_device() args.batch_size_per_gpu = args.bs // dist.get_world_size() @@ -137,7 +129,4 @@ def init_dist_and_get_args(): args.lr = args.base_lr * args.glb_batch_size / 256 args.wde = args.wde or args.wd - if args.hierarchy < 1: - args.hierarchy = 1 - return args
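
> **Editor's note (illustration only):** to make the retained defaults in `arg_util.py` concrete, here is a tiny standalone sketch of the batch/LR arithmetic. Only the two formulas `bs // world_size` and `base_lr * glb_batch_size / 256` come from the code above; the reconstruction of `glb_batch_size` and the numeric values in the example (including `base_lr=2e-4`) are assumptions for illustration.

```python
# Standalone arithmetic sketch -- illustrative values, not taken from the repo.
def resolve_batch_and_lr(bs: int, base_lr: float, world_size: int):
    batch_size_per_gpu = bs // world_size             # per-GPU batch, as in arg_util.py
    glb_batch_size = batch_size_per_gpu * world_size  # effective global batch (assumed reconstruction)
    lr = base_lr * glb_batch_size / 256               # linear LR scaling rule, as in arg_util.py
    return batch_size_per_gpu, glb_batch_size, lr


if __name__ == '__main__':
    # e.g. the README's default global batch of 4096 on one 8-GPU machine,
    # with a purely hypothetical base_lr of 2e-4:
    print(resolve_batch_and_lr(bs=4096, base_lr=2e-4, world_size=8))
    # -> (512, 4096, 0.0032)
```
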