[upd] add support for 1GPU debug and customized dataset

main
keyu-tian 1 year ago
parent cba9d26a66
commit 36aa8210b2
  1. README.md (3 changed lines)
  2. pretrain/README.md (44 changed lines)
  3. pretrain/dist.py (46 changed lines)
  4. pretrain/main.py (18 changed lines)
  5. pretrain/utils/arg_util.py (2 changed lines)
  6. pretrain/utils/imagenet.py (17 changed lines)

@@ -95,7 +95,8 @@ We also provide [pretrain/viz_spconv.ipynb](pretrain/viz_spconv.ipynb) that show
<summary>catalog</summary>
- [x] Pretraining code
- [x] Pretraining tutorial for custom CNN model ([pretrain/models/custom.py](pretrain/models/custom.py))
- [x] Pretraining tutorial for customized CNN model ([Tutorial for pretraining your own CNN model](https://github.com/keyu-tian/SparK/tree/main/pretrain/#tutorial-for-pretraining-your-own-cnn-model))
- [x] Pretraining tutorial for customized dataset ([Tutorial for pretraining your own dataset](https://github.com/keyu-tian/SparK/tree/main/pretrain/#tutorial-for-pretraining-your-own-dataset))
- [x] Pretraining Colab visualization playground ([reconstruction](/pretrain/viz_reconstruction.ipynb), [sparse conv](pretrain/viz_spconv.ipynb))
- [x] Finetuning code
- [ ] Weights & visualization playground in `huggingface`

@@ -1,11 +1,11 @@
## Preparation for ImageNet-1k pre-training
## Preparation for ImageNet-1k pretraining
See [/INSTALL.md](/INSTALL.md) to prepare `pip` dependencies and the ImageNet dataset.
**Note: for network definitions, we directly use `timm.models.ResNet` and [official ConvNeXt](https://github.com/facebookresearch/ConvNeXt/blob/048efcea897d999aed302f2639b6270aedf8d4c8/models/convnext.py).**
## Tutorial for customizing your own CNN model
## Tutorial for pretraining your own CNN model
See [/pretrain/models/custom.py](/pretrain/models/custom.py). The things you need to do are:
@@ -15,19 +15,37 @@ See [/pretrain/models/custom.py](/pretrain/models/custom.py). The things needed
- define `your_convnet(...)` with `@register_model` in [/pretrain/models/custom.py line54](/pretrain/models/custom.py#L53-L54).
- add default kwargs of `your_convnet(...)` in [/pretrain/models/\_\_init\_\_.py line34](/pretrain/models/__init__.py#L34).
Then you can use `--model=your_convnet` in the pre-training script.
Then you can use `--model=your_convnet` in the pretraining script.
## Pre-training Any Model on ImageNet-1k (224x224)
## Tutorial for pretraining your own dataset
For pre-training, run [/pretrain/main.sh](/pretrain/main.sh) with bash.
Replace the function `build_dataset_to_pretrain` in [line54-75 of /pretrain/utils/imagenet.py](/pretrain/utils/imagenet.py#L54-L75) with your own.
This function should return a `Dataset` object. You may use arguments like `args.data_path` and `args.input_size` to help build your dataset, and when running an experiment with `main.sh` you can pass `--data_path=... --input_size=...` to specify them.
Note that the batch size `--bs` is the total batch size across all GPUs, and it may also need to be tuned.
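For illustration, here is a minimal sketch of what such a replacement could look like for a plain folder of unlabeled images. It is not part of this commit: the `FlatFolderDataset` helper and its file-extension filter are assumptions, the `interpolation` argument used by the original transform is omitted, and the `(image, dummy_target)` return format simply mirrors the repo's `ImageNetDataset`.
```python
import os
import PIL.Image as PImage
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torch.utils.data import Dataset
from torchvision.transforms import transforms


class FlatFolderDataset(Dataset):  # hypothetical helper, not in the repo
    def __init__(self, folder, transform):
        # collect all images directly under `folder`
        self.files = [os.path.join(folder, f) for f in sorted(os.listdir(folder))
                      if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        img = PImage.open(self.files[index]).convert('RGB')
        return self.transform(img), 0  # dummy target, mirroring ImageNetDataset's (image, target)


def build_dataset_to_pretrain(dataset_path, input_size) -> Dataset:
    trans_train = transforms.Compose([
        transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])
    return FlatFolderDataset(os.path.abspath(dataset_path), trans_train)
```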
## Debug on 1 GPU (without DistributedDataParallel)
Use a small batch size such as `--bs=32` to avoid OOM.
```shell script
python3 main.py --exp_name=debug --data_path=/path/to/imagenet --model=resnet50 --bs=32
```
## Pretraining Any Model on ImageNet-1k (224x224)
For pretraining, run [/pretrain/main.sh](/pretrain/main.sh) with bash.
It is **required** to specify the ImageNet data folder (`--data_path`), the model name (`--model`), and your experiment name (the first argument of `main.sh`) when running the script.
We use the **same** pre-training configurations (lr, batch size, etc.) for all models (ResNets and ConvNeXts).
We use the **same** pretraining configurations (lr, batch size, etc.) for all models (ResNets and ConvNeXts).
Their names and **default values** can be found in [/pretrain/utils/arg_util.py line23-44](/pretrain/utils/arg_util.py#L23-L44).
These default configurations (like batch size 4096) will be used unless you override them, e.g. with `--bs=512`.
Here is an example command pre-training a ResNet50 on a single machine with 8 GPUs:
**Note: the batch size `--bs` is the total batch size across all GPUs, and the learning rate `--lr` is the base learning rate. The actual learning rate is `lr * bs / 256`, as in [/pretrain/utils/arg_util.py line131](/pretrain/utils/arg_util.py#L131).**
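For example (numbers chosen only for illustration), running with `--bs=4096 --lr=2e-4` would give an actual learning rate of `2e-4 * 4096 / 256 = 3.2e-3`.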
Here is an example command pretraining a ResNet50 on a single machine with 8 GPUs (we use DistributedDataParallel):
```shell script
$ cd /path/to/SparK/pretrain
$ bash ./main.sh <experiment_name> \
@@ -44,9 +62,9 @@ For multiple machines, change the `--num_nodes` to your count, and plus these ar
Note that `<experiment_name>` is the name of your experiment; it will be used to create an output directory named `output_<experiment_name>`.
## Pre-training ConvNeXt-Large on ImageNet-1k (384x384)
## Pretraining ConvNeXt-Large on ImageNet-1k (384x384)
For pre-training with resolution 384, we use a larger mask ratio (0.75), a smaller batch size (2048), and a larger learning rate (4e-4):
For pretraining with resolution 384, we use a larger mask ratio (0.75), a smaller batch size (2048), and a larger learning rate (4e-4):
```shell script
$ cd /path/to/SparK/pretrain
@@ -61,13 +79,13 @@ $ bash ./main.sh <experiment_name> \
Once an experiment starts running, the following files would be automatically created and updated in `output_<experiment_name>`:
- `<model>_still_pretraining.pth`: saves model and optimizer states, current epoch, current reconstruction loss, etc; can be used to resume pre-training
- `<model>_1kpretrained.pth`: can be used for downstream fine-tuning
- `<model>_still_pretraining.pth`: saves model and optimizer states, current epoch, current reconstruction loss, etc; can be used to resume pretraining
- `<model>_1kpretrained.pth`: can be used for downstream finetuning
- `pretrain_log.txt`: records some important information such as:
- `git_commit_id`: git version
- `cmd`: all arguments passed to the script
It also reports the loss and remaining pre-training time at each epoch.
It also reports the loss and remaining pretraining time at each epoch.
- `stdout_backup.txt` and `stderr_backup.txt`: will save all output to stdout/stderr
@@ -97,4 +115,4 @@ Here is the reason: when we do mask, we:
3. then progressively upsample it (i.e., expand its 2nd and 3rd dimensions by calling `repeat_interleave(..., 2)` and `repeat_interleave(..., 3)` in [/pretrain/encoder.py line16](/pretrain/encoder.py#L16)), to mask those feature maps ([`x` in line21](/pretrain/encoder.py#L21)) with larger resolutions.
So if you want a patch size of 16 or 8, you should actually define a new CNN model with a downsample ratio of 16 or 8.
See [Tutorial for customizing your own CNN model (above)](https://github.com/keyu-tian/SparK/tree/main/pretrain/#tutorial-for-customizing-your-own-cnn-model).
See [Tutorial for pretraining your own CNN model (above)](https://github.com/keyu-tian/SparK/tree/main/pretrain/#tutorial-for-pretraining-your-own-cnn-model).
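For intuition, here is a small self-contained sketch of the progressive mask upsampling described above. The shapes are assumptions (a 224x224 input with the default downsample ratio of 32), not code from the repo:
```python
import torch

B = 2
mask = (torch.rand(B, 1, 7, 7) > 0.6).float()                            # mask at the coarsest scale (224 / 32 = 7)
mask16 = mask.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3)    # 14x14, matches stride-16 feature maps
mask8 = mask16.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3)   # 28x28, matches stride-8 feature maps
# a feature map x of matching resolution is then masked elementwise, e.g. x * mask16
```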

@@ -22,19 +22,24 @@ def initialized():
def initialize(backend='nccl'):
    global __device
    if not torch.cuda.is_available():
        print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
        return
    elif 'RANK' not in os.environ:
        __device = torch.empty(1).cuda().device
        print(f'[dist initialize] RANK is not set, use 1 GPU instead', file=sys.stderr)
        return
    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    global_rank, num_gpus = int(os.environ.get('RANK', 'error')), torch.cuda.device_count()
    global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)
    tdist.init_process_group(backend=backend)
    global __rank, __local_rank, __world_size, __device, __initialized
    global __rank, __local_rank, __world_size, __initialized
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __device = torch.empty(1).cuda().device
@@ -81,28 +86,33 @@ def parallelize(net, syncbn=False):
def allreduce(t: torch.Tensor) -> None:
    if not t.is_cuda:
        cu = t.detach().cuda()
        tdist.all_reduce(cu)
        t.copy_(cu.cpu())
    else:
        tdist.all_reduce(t)
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            tdist.all_reduce(cu)
            t.copy_(cu.cpu())
        else:
            tdist.all_reduce(t)

def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    if not t.is_cuda:
        t = t.cuda()
    ls = [torch.empty_like(t) for _ in range(__world_size)]
    tdist.all_gather(ls, t)
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()
        ls = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls, t)
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls

def broadcast(t: torch.Tensor, src_rank) -> None:
    if not t.is_cuda:
        cu = t.detach().cuda()
        tdist.broadcast(cu, src=src_rank)
        t.copy_(cu.cpu())
    else:
        tdist.broadcast(t, src=src_rank)
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            tdist.broadcast(cu, src=src_rank)
            t.copy_(cu.cpu())
        else:
            tdist.broadcast(t, src=src_rank)
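To illustrate the new single-GPU fallback, here is a hedged usage sketch. It assumes the script is run from the `pretrain/` directory on a machine with at least one GPU and with `RANK` unset in the environment; only functions visible in this diff (`initialize`, `initialized`, `get_device`, `allreduce`, `allgather`) are used.
```python
import torch
import dist  # pretrain/dist.py

dist.initialize()            # prints '[dist initialize] RANK is not set, use 1 GPU instead' and returns early
print(dist.initialized())    # False: no process group was created

t = torch.ones(4, device=dist.get_device())
dist.allreduce(t)            # no-op, since __initialized is False
gathered = dist.allgather(t) # single-element "gather": effectively just t back
```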

@@ -22,10 +22,19 @@ from models import build_sparse_encoder
from sampler import DistInfiniteBatchSampler, worker_init_fn
from spark import SparK
from utils import arg_util, misc, lamb
from utils.imagenet import build_imagenet_pretrain
from utils.imagenet import build_dataset_to_pretrain
from utils.lr_control import lr_wd_annealing, get_param_groups

class LocalDDP(torch.nn.Module):
    def __init__(self, module):
        super(LocalDDP, self).__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

def main_pt():
    args: arg_util.Args = arg_util.init_dist_and_get_args()
    print(f'initial args:\n{str(args)}')
@@ -33,7 +42,7 @@ def main_pt():
    # build data
    print(f'[build data for pre-training] ...\n')
    dataset_train = build_imagenet_pretrain(args.data_path, args.input_size)
    dataset_train = build_dataset_to_pretrain(args.data_path, args.input_size)
    data_loader_train = DataLoader(
        dataset=dataset_train, num_workers=args.dataloader_workers, pin_memory=True,
        batch_sampler=DistInfiniteBatchSampler(
@@ -52,7 +61,10 @@ def main_pt():
        densify_norm=args.densify_norm, sbn=args.sbn,
    ).to(args.device)
    print(f'[PT model] model = {model_without_ddp}\n')
    model: DistributedDataParallel = DistributedDataParallel(model_without_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
    if dist.initialized():
        model: DistributedDataParallel = DistributedDataParallel(model_without_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
    else:
        model = LocalDDP(model_without_ddp)

    # build optimizer and lr_scheduler
    param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'cls_token', 'pos_embed', 'mask_token', 'gamma'})
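A small illustrative sketch (a toy `torch.nn.Linear` stands in for the real SparK model) of why the `LocalDDP` wrapper keeps the rest of the loop unchanged: both branches expose `model(...)` and `model.module`, so forward calls and checkpointing work the same with or without DDP.
```python
import torch

net = torch.nn.Linear(8, 2)       # stand-in for the real model
model = LocalDDP(net)             # single-GPU debug branch
assert model.module is net        # same attribute that DistributedDataParallel exposes
out = model(torch.randn(4, 8))    # forwards straight to net
```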

@@ -112,6 +112,8 @@ def init_dist_and_get_args():
    misc.init_distributed_environ(exp_dir=args.exp_dir)
    # update args
    if not dist.initialized():
        args.sbn = False
    args.first_logging = True
    args.device = dist.get_device()
    args.batch_size_per_gpu = args.bs // dist.get_world_size()
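For example (illustrative numbers), with the default total batch size `--bs=4096` on one 8-GPU machine, each GPU processes `4096 // 8 = 512` samples per step.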

@@ -11,6 +11,7 @@ import PIL.Image as PImage
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS
from torchvision.transforms import transforms
from torch.utils.data import Dataset
try:
    from torchvision.transforms import InterpolationMode
@@ -50,7 +51,13 @@ class ImageNetDataset(DatasetFolder):
        return self.transform(self.loader(path)), target

def build_imagenet_pretrain(imagenet_folder, input_size):
def build_dataset_to_pretrain(dataset_path, input_size) -> Dataset:
    """
    You may need to modify this function to fit your own dataset.
    :param dataset_path: the folder of your dataset
    :param input_size: the input size (image resolution)
    :return: the dataset used for pretraining
    """
    trans_train = transforms.Compose([
        transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation),
        transforms.RandomHorizontalFlip(),
@@ -58,12 +65,12 @@ def build_imagenet_pretrain(imagenet_folder, input_size):
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])
    imagenet_folder = os.path.abspath(imagenet_folder)
    dataset_path = os.path.abspath(dataset_path)
    for postfix in ('train', 'val'):
        if imagenet_folder.endswith(postfix):
            imagenet_folder = imagenet_folder[:-len(postfix)]
        if dataset_path.endswith(postfix):
            dataset_path = dataset_path[:-len(postfix)]
    dataset_train = ImageNetDataset(imagenet_folder=imagenet_folder, transform=trans_train, train=True)
    dataset_train = ImageNetDataset(imagenet_folder=dataset_path, transform=trans_train, train=True)
    print_transform(trans_train, '[pre-train]')
    return dataset_train
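For reference, a hypothetical call mirroring how `main.py` invokes this function (`build_dataset_to_pretrain(args.data_path, args.input_size)`); the paths are placeholders, as in the README examples above.
```python
# a trailing 'train' or 'val' in the path is stripped by the loop above,
# so both calls resolve to the same ImageNet root folder
ds_a = build_dataset_to_pretrain('/path/to/imagenet/train', 224)
ds_b = build_dataset_to_pretrain('/path/to/imagenet', 224)
```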
