**Note: for network definitions, we directly use `timm.models.ResNet` and [official ConvNeXt](https://github.com/facebookresearch/ConvNeXt/blob/048efcea897d999aed302f2639b6270aedf8d4c8/models/convnext.py).**
Replace the function `build_dataset_to_pretrain` in [line54-75 of /pretrain/utils/imagenet.py](/pretrain/utils/imagenet.py#L54-L75) with your own implementation.
This function should return a `Dataset` object. You may use arguments such as `args.data_path` and `args.input_size` to build your dataset. When running an experiment with `main.sh`, you can set them with `--data_path=... --input_size=...`.
Note that the batch size `--bs` is the total batch size across all GPUs, and it may also need to be tuned.
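As a rough illustration, assuming the global batch is split evenly across processes, the per-GPU batch size follows directly from `--bs` and the number of GPUs. The helper `per_gpu_batch_size` is hypothetical and not part of the repository, and requiring exact divisibility is an assumption.

```python
def per_gpu_batch_size(total_bs: int, world_size: int) -> int:
    """Split the global batch size (`--bs`) evenly across `world_size` GPUs."""
    if total_bs % world_size != 0:
        raise ValueError(f'--bs={total_bs} is not divisible by {world_size} GPUs')
    return total_bs // world_size


per_gpu_batch_size(4096, 32)  # → 128
```

Keeping `--bs` fixed while doubling the number of GPUs therefore halves the per-GPU batch (and its memory footprint), whereas keeping the per-GPU batch fixed means doubling `--bs`.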
## Debug on 1 GPU (without DistributedDataParallel)
Use a small batch size such as `--bs=32` to avoid running out of GPU memory (OOM).
It is **required** to specify the ImageNet data folder (`--data_path`), the model name (`--model`), and your experiment name (the first argument of `main.sh`) when running the script.
**Note: the batch size `--bs` is the total batch size across all GPUs, and `--base_lr` is the base learning rate. The actual learning rate is `base_lr * bs / 256`, as in [/pretrain/utils/arg_util.py line131](/pretrain/utils/arg_util.py#L131). Do not use `--lr` to specify a learning rate; it would be ignored.**
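The linear scaling rule stated in the note can be written as a one-line helper. The name `actual_lr` and the example `base_lr` value `2e-4` are made up for illustration; the authoritative computation is the one in `arg_util.py`.

```python
def actual_lr(base_lr: float, total_bs: int) -> float:
    """Linear scaling rule: the learning rate grows proportionally with `--bs`."""
    return base_lr * total_bs / 256


lr_full = actual_lr(2e-4, 4096)  # approximately 3.2e-3
lr_debug = actual_lr(2e-4, 32)   # approximately 2.5e-5 (the 1-GPU debug setting)
```

Because the rule is linear, changing `--bs` alone also changes the effective learning rate; `--base_lr` only needs retuning if you want a different value per 256 samples.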
1. First, generate the binary mask for the **smallest**-resolution feature map, i.e., the `_cur_active` or `active_b1ff` in [/pretrain/spark.py line86-87](/pretrain/spark.py#L86-L87). It is a `torch.BoolTensor` of shape `[B, 1, fmap_size, fmap_size]` and is used to mask the smallest feature map.
2. Then progressively upsample it (i.e., expand its spatial dimensions, dims 2 and 3, by calling `repeat_interleave(..., 2)` and `repeat_interleave(..., 3)` in [/pretrain/encoder.py line16](/pretrain/encoder.py#L16)) to mask the feature maps with larger resolutions ([`x` in line21](/pretrain/encoder.py#L21)).
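The two steps can be sketched in PyTorch as follows. This is an illustrative sketch, not the code in `spark.py` or `encoder.py`. The helper names are hypothetical, and choosing active positions by a per-sample random permutation with a fixed `mask_ratio` is an assumption. `True` marks a visible (active) position.

```python
from typing import Optional

import torch


def make_active_mask(batch_size: int, fmap_size: int, mask_ratio: float,
                     generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Step 1 (sketch): random bool mask for the smallest feature map, [B, 1, f, f]."""
    num_pos = fmap_size * fmap_size
    num_keep = round(num_pos * (1 - mask_ratio))
    noise = torch.rand(batch_size, num_pos, generator=generator)
    # Rank of each position in a random per-sample ordering; keep the lowest ranks,
    # so every sample has exactly `num_keep` active positions.
    ranks = noise.argsort(dim=1).argsort(dim=1)
    active = ranks < num_keep
    return active.view(batch_size, 1, fmap_size, fmap_size)


def upsample_mask(active_b1ff: torch.Tensor, target_size: int) -> torch.Tensor:
    """Step 2 (sketch): repeat along dims 2 and 3 to match a larger feature map."""
    scale = target_size // active_b1ff.shape[-1]
    assert scale * active_b1ff.shape[-1] == target_size, 'sizes must divide evenly'
    return active_b1ff.repeat_interleave(scale, 2).repeat_interleave(scale, 3)


def mask_feature_map(x: torch.Tensor, active_b1ff: torch.Tensor) -> torch.Tensor:
    """Zero out masked positions of x ([B, C, H, W]); the mask broadcasts over C."""
    return x * upsample_mask(active_b1ff, x.shape[-1])
```

Since every entry is repeated `scale` times along both height and width, each cell of the smallest mask covers a `scale x scale` patch of the larger feature map, so all resolutions hide the same image regions.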
See [Tutorial for pretraining your own CNN model (above)](https://github.com/keyu-tian/SparK/tree/main/pretrain/#tutorial-for-pretraining-your-own-cnn-model).