From 36aa8210b2580b793cd514fd81096a99519ac6e6 Mon Sep 17 00:00:00 2001
From: keyu-tian
Date: Wed, 12 Apr 2023 14:27:40 +0800
Subject: [PATCH] [upd] add support for 1GPU debug and customized dataset

---
 README.md                  |  3 ++-
 pretrain/README.md         | 44 +++++++++++++++++++++++++-----------
 pretrain/dist.py           | 46 +++++++++++++++++++++++---------------
 pretrain/main.py           | 18 ++++++++++++---
 pretrain/utils/arg_util.py |  2 ++
 pretrain/utils/imagenet.py | 17 +++++++++-----
 6 files changed, 90 insertions(+), 40 deletions(-)

diff --git a/README.md b/README.md
index cf9c0d1..7020f1d 100644
--- a/README.md
+++ b/README.md
@@ -95,7 +95,8 @@ We also provide [pretrain/viz_spconv.ipynb](pretrain/viz_spconv.ipynb) that show
 catalog
 - [x] Pretraining code
-- [x] Pretraining toturial for custom CNN model ([pretrain/models/custom.py](pretrain/models/custom.py))
+- [x] Pretraining tutorial for customized CNN model ([Tutorial for pretraining your own CNN model](https://github.com/keyu-tian/SparK/tree/main/pretrain/#tutorial-for-pretraining-your-own-cnn-model))
+- [x] Pretraining tutorial for customized dataset ([Tutorial for pretraining your own dataset](https://github.com/keyu-tian/SparK/tree/main/pretrain/#tutorial-for-pretraining-your-own-dataset))
 - [x] Pretraining Colab visualization playground ([reconstruction](/pretrain/viz_reconstruction.ipynb), [sparse conv](pretrain/viz_spconv.ipynb))
 - [x] Finetuning code
 - [ ] Weights & visualization playground in `huggingface`
diff --git a/pretrain/README.md b/pretrain/README.md
index f12b6dd..52bd0b8 100644
--- a/pretrain/README.md
+++ b/pretrain/README.md
@@ -1,11 +1,11 @@
-## Preparation for ImageNet-1k pre-training
+## Preparation for ImageNet-1k pretraining
 
 See [/INSTALL.md](/INSTALL.md) to prepare `pip` dependencies and the ImageNet dataset.
 
 **Note: for network definitions, we directly use `timm.models.ResNet` and [official ConvNeXt](https://github.com/facebookresearch/ConvNeXt/blob/048efcea897d999aed302f2639b6270aedf8d4c8/models/convnext.py).**
 
-## Tutorial for customizing your own CNN model
+## Tutorial for pretraining your own CNN model
 
 See [/pretrain/models/custom.py](/pretrain/models/custom.py). The things needed to do is:
@@ -15,19 +15,37 @@ See [/pretrain/models/custom.py](/pretrain/models/custom.py). The things needed
 - define `your_convnet(...)` with `@register_model` in [/pretrain/models/custom.py line54](/pretrain/models/custom.py#L53-L54).
 - add default kwargs of `your_convnet(...)` in [/pretrain/models/\_\_init\_\_.py line34](/pretrain/models/__init__.py#L34).
 
-Then you can use `--model=your_convnet` in the pre-training script.
+Then you can use `--model=your_convnet` in the pretraining script.
 
 
-## Pre-training Any Model on ImageNet-1k (224x224)
+## Tutorial for pretraining your own dataset
 
-For pre-training, run [/pretrain/main.sh](/pretrain/main.sh) with bash.
+Replace the function `build_dataset_to_pretrain` in [line54-75 of /pretrain/utils/imagenet.py](/pretrain/utils/imagenet.py#L54-L75) with your own.
+This function should return a `Dataset` object. You may use args like `args.data_path` and `args.input_size` to help build your dataset, and when running an experiment with `main.sh` you can pass `--data_path=... --input_size=...` to specify them.
+Note that the batch size `--bs` is the total batch size across all GPUs, which may also need to be tuned.
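+
+A minimal sketch of such a replacement is shown below. It assumes an `ImageFolder`-style directory layout and reuses ImageNet's mean/std; both are assumptions about your data, so adapt them as needed:
+
+```python
+from torch.utils.data import Dataset
+from torchvision.datasets import ImageFolder
+from torchvision.transforms import transforms
+
+def build_dataset_to_pretrain(dataset_path, input_size) -> Dataset:
+    trans_train = transforms.Compose([
+        transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0)),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        # replace with your dataset's statistics if they differ a lot from ImageNet's
+        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+    ])
+    # ImageFolder yields (image, label) pairs, matching what the default ImageNetDataset returns
+    return ImageFolder(root=dataset_path, transform=trans_train)
+```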
+
+
+## Debug on 1 GPU (without DistributedDataParallel)
+
+Use a small batch size `--bs=32` to avoid OOM.
+
+```shell script
+python3 main.py --exp_name=debug --data_path=/path/to/imagenet --model=resnet50 --bs=32
+```
+
+
+## Pretraining Any Model on ImageNet-1k (224x224)
+
+For pretraining, run [/pretrain/main.sh](/pretrain/main.sh) with bash.
 It is **required** to specify the ImageNet data folder (`--data_path`), the model name (`--model`), and your experiment name (the first argument of `main.sh`) when running the script.
 
-We use the **same** pre-training configurations (lr, batch size, etc.) for all models (ResNets and ConvNeXts).
+We use the **same** pretraining configurations (lr, batch size, etc.) for all models (ResNets and ConvNeXts).
 Their names and **default values** can be found in [/pretrain/utils/arg_util.py line23-44](/pretrain/utils/arg_util.py#L23-L44).
 These default configurations (like batch size 4096) would be used, unless you specify some like `--bs=512`.
 
-Here is an example command pre-training a ResNet50 on single machine with 8 GPUs:
+**Note: the batch size `--bs` is the total batch size across all GPUs, and the learning rate `--lr` is the base learning rate. The actual learning rate will be `lr * bs / 256`, as in [/pretrain/utils/arg_util.py line131](/pretrain/utils/arg_util.py#L131).**
+
+Here is an example command pretraining a ResNet50 on a single machine with 8 GPUs (we use DistributedDataParallel):
 ```shell script
 $ cd /path/to/SparK/pretrain
 $ bash ./main.sh \
@@ -44,9 +62,9 @@ For multiple machines, change the `--num_nodes` to your count, and plus these ar
 Note the `` is the name of your experiment, which would be used to create an output directory named `output_`.
 
-## Pre-training ConvNeXt-Large on ImageNet-1k (384x384)
+## Pretraining ConvNeXt-Large on ImageNet-1k (384x384)
 
-For pre-training with resolution 384, we use a larger mask ratio (0.75), a smaller batch size (2048), and a larger learning rate (4e-4):
+For pretraining with resolution 384, we use a larger mask ratio (0.75), a smaller batch size (2048), and a larger learning rate (4e-4):
 
 ```shell script
 $ cd /path/to/SparK/pretrain
 $ bash ./main.sh \
@@ -61,13 +79,13 @@
 
 Once an experiment starts running, the following files would be automatically created and updated in `output_`:
 
-- `_still_pretraining.pth`: saves model and optimizer states, current epoch, current reconstruction loss, etc; can be used to resume pre-training
-- `_1kpretrained.pth`: can be used for downstream fine-tuning
+- `_still_pretraining.pth`: saves model and optimizer states, current epoch, current reconstruction loss, etc.; can be used to resume pretraining
+- `_1kpretrained.pth`: can be used for downstream finetuning
 - `pretrain_log.txt`: records some important information such as:
     - `git_commit_id`: git version
     - `cmd`: all arguments passed to the script
 
-    It also reports the loss and remaining pre-training time at each epoch.
+    It also reports the loss and remaining pretraining time at each epoch.
 
 - `stdout_backup.txt` and `stderr_backup.txt`: will save all output to stdout/stderr
 
@@ -97,4 +115,4 @@ Here is the reason: when we do mask, we:
 3. then progressively upsample it (i.e., expand its 2nd and 3rd dimensions by calling `repeat_interleave(..., 2)` and `repeat_interleave(..., 3)` in [/pretrain/encoder.py line16](/pretrain/encoder.py#L16)), to mask those feature maps ([`x` in line21](/pretrain/encoder.py#L21)) with larger resolutions .
 So if you want a patch size of 16 or 8, you should actually define a new CNN model with a downsample ratio of 16 or 8.
-See [Tutorial for customizing your own CNN model (above)](https://github.com/keyu-tian/SparK/tree/main/pretrain/#tutorial-for-customizing-your-own-cnn-model).
+See [Tutorial for pretraining your own CNN model (above)](https://github.com/keyu-tian/SparK/tree/main/pretrain/#tutorial-for-pretraining-your-own-cnn-model).
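+
+To make the `repeat_interleave` upsampling in step 3 above concrete, here is a small self-contained sketch (the tensor shapes and the keep ratio below are illustrative only, not SparK's actual values):
+
+```python
+import torch
+
+# binary mask on the smallest (1/32) feature map of a 224x224 input: 7x7
+active = torch.rand(1, 1, 7, 7) > 0.6      # True = kept, False = masked
+
+# expand the 2nd and 3rd dimensions by 2x with repeat_interleave, so the same
+# mask pattern can be applied to the 14x14 feature map, then 28x28, and so on
+active_14 = active.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3)
+active_28 = active_14.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3)
+
+feat_28 = torch.randn(1, 128, 28, 28)      # a feature map at 1/8 resolution
+masked_28 = feat_28 * active_28            # zero out the masked positions
+```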
diff --git a/pretrain/dist.py b/pretrain/dist.py
index bae9379..fc05ad1 100644
--- a/pretrain/dist.py
+++ b/pretrain/dist.py
@@ -22,19 +22,24 @@ def initialized():
 
 
 def initialize(backend='nccl'):
+    global __device
     if not torch.cuda.is_available():
         print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
         return
+    elif 'RANK' not in os.environ:
+        __device = torch.empty(1).cuda().device
+        print(f'[dist initialize] RANK is not set, use 1 GPU instead', file=sys.stderr)
+        return
 
     # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
     if mp.get_start_method(allow_none=True) is None:
         mp.set_start_method('spawn')
-    global_rank, num_gpus = int(os.environ.get('RANK', 'error')), torch.cuda.device_count()
+    global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
     local_rank = global_rank % num_gpus
     torch.cuda.set_device(local_rank)
     tdist.init_process_group(backend=backend)
 
-    global __rank, __local_rank, __world_size, __device, __initialized
+    global __rank, __local_rank, __world_size, __initialized
     __local_rank = local_rank
     __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
     __device = torch.empty(1).cuda().device
@@ -81,28 +86,33 @@ def parallelize(net, syncbn=False):
 
 
 def allreduce(t: torch.Tensor) -> None:
-    if not t.is_cuda:
-        cu = t.detach().cuda()
-        tdist.all_reduce(cu)
-        t.copy_(cu.cpu())
-    else:
-        tdist.all_reduce(t)
+    if __initialized:
+        if not t.is_cuda:
+            cu = t.detach().cuda()
+            tdist.all_reduce(cu)
+            t.copy_(cu.cpu())
+        else:
+            tdist.all_reduce(t)
 
 
 def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
-    if not t.is_cuda:
-        t = t.cuda()
-    ls = [torch.empty_like(t) for _ in range(__world_size)]
-    tdist.all_gather(ls, t)
+    if __initialized:
+        if not t.is_cuda:
+            t = t.cuda()
+        ls = [torch.empty_like(t) for _ in range(__world_size)]
+        tdist.all_gather(ls, t)
+    else:
+        ls = [t]
     if cat:
         ls = torch.cat(ls, dim=0)
     return ls
 
 
 def broadcast(t: torch.Tensor, src_rank) -> None:
-    if not t.is_cuda:
-        cu = t.detach().cuda()
-        tdist.broadcast(cu, src=src_rank)
-        t.copy_(cu.cpu())
-    else:
-        tdist.broadcast(t, src=src_rank)
+    if __initialized:
+        if not t.is_cuda:
+            cu = t.detach().cuda()
+            tdist.broadcast(cu, src=src_rank)
+            t.copy_(cu.cpu())
+        else:
+            tdist.broadcast(t, src=src_rank)
diff --git a/pretrain/main.py b/pretrain/main.py
index 6fbcbdc..a7c9fdd 100644
--- a/pretrain/main.py
+++ b/pretrain/main.py
@@ -22,10 +22,19 @@ from models import build_sparse_encoder
 from sampler import DistInfiniteBatchSampler, worker_init_fn
 from spark import SparK
 from utils import arg_util, misc, lamb
-from utils.imagenet import build_imagenet_pretrain
+from utils.imagenet import build_dataset_to_pretrain
 from utils.lr_control import lr_wd_annealing, get_param_groups
 
 
+class LocalDDP(torch.nn.Module):
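+    """
+    A minimal stand-in for DistributedDataParallel when torch.distributed is not
+    initialized (e.g., single-GPU debugging): it only stores the wrapped module and
+    forwards calls to it, so the rest of the code can treat both cases uniformly.
+    """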
+    def __init__(self, module):
+        super(LocalDDP, self).__init__()
+        self.module = module
+    
+    def forward(self, *args, **kwargs):
+        return self.module(*args, **kwargs)
+
+
 def main_pt():
     args: arg_util.Args = arg_util.init_dist_and_get_args()
     print(f'initial args:\n{str(args)}')
@@ -33,7 +42,7 @@ def main_pt():
 
     # build data
     print(f'[build data for pre-training] ...\n')
-    dataset_train = build_imagenet_pretrain(args.data_path, args.input_size)
+    dataset_train = build_dataset_to_pretrain(args.data_path, args.input_size)
     data_loader_train = DataLoader(
         dataset=dataset_train, num_workers=args.dataloader_workers, pin_memory=True,
         batch_sampler=DistInfiniteBatchSampler(
@@ -52,7 +61,10 @@ def main_pt():
         densify_norm=args.densify_norm, sbn=args.sbn,
     ).to(args.device)
     print(f'[PT model] model = {model_without_ddp}\n')
-    model: DistributedDataParallel = DistributedDataParallel(model_without_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
+    if dist.initialized():
+        model: DistributedDataParallel = DistributedDataParallel(model_without_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
+    else:
+        model = LocalDDP(model_without_ddp)
 
     # build optimizer and lr_scheduler
     param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'cls_token', 'pos_embed', 'mask_token', 'gamma'})
diff --git a/pretrain/utils/arg_util.py b/pretrain/utils/arg_util.py
index 46a9865..95f80d0 100644
--- a/pretrain/utils/arg_util.py
+++ b/pretrain/utils/arg_util.py
@@ -112,6 +112,8 @@ def init_dist_and_get_args():
     misc.init_distributed_environ(exp_dir=args.exp_dir)
 
     # update args
+    if not dist.initialized():
+        args.sbn = False
     args.first_logging = True
     args.device = dist.get_device()
     args.batch_size_per_gpu = args.bs // dist.get_world_size()
diff --git a/pretrain/utils/imagenet.py b/pretrain/utils/imagenet.py
index 9473355..3249281 100644
--- a/pretrain/utils/imagenet.py
+++ b/pretrain/utils/imagenet.py
@@ -11,6 +11,7 @@ import PIL.Image as PImage
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS
 from torchvision.transforms import transforms
+from torch.utils.data import Dataset
 
 try:
     from torchvision.transforms import InterpolationMode
@@ -50,7 +51,13 @@ class ImageNetDataset(DatasetFolder):
         return self.transform(self.loader(path)), target
 
 
-def build_imagenet_pretrain(imagenet_folder, input_size):
+def build_dataset_to_pretrain(dataset_path, input_size) -> Dataset:
+    """
+    You may need to modify this function to fit your own dataset.
+    :param dataset_path: the root folder of your dataset
+    :param input_size: the input size (image resolution)
+    :return: the dataset used for pretraining
+    """
     trans_train = transforms.Compose([
         transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation),
         transforms.RandomHorizontalFlip(),
@@ -58,12 +65,12 @@ def build_imagenet_pretrain(imagenet_folder, input_size):
         transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     ])
 
-    imagenet_folder = os.path.abspath(imagenet_folder)
+    dataset_path = os.path.abspath(dataset_path)
     for postfix in ('train', 'val'):
-        if imagenet_folder.endswith(postfix):
-            imagenet_folder = imagenet_folder[:-len(postfix)]
+        if dataset_path.endswith(postfix):
+            dataset_path = dataset_path[:-len(postfix)]
 
-    dataset_train = ImageNetDataset(imagenet_folder=imagenet_folder, transform=trans_train, train=True)
+    dataset_train = ImageNetDataset(imagenet_folder=dataset_path, transform=trans_train, train=True)
     print_transform(trans_train, '[pre-train]')
     return dataset_train
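
As a quick aside on `build_dataset_to_pretrain` above: a tiny standalone check like the following (run from the `pretrain/` directory; the dataset path is a placeholder) can confirm that your `--data_path` and `--input_size` produce what the pretraining loop expects, i.e. an image tensor of shape `[3, input_size, input_size]`:

```python
from utils.imagenet import build_dataset_to_pretrain

dataset = build_dataset_to_pretrain('/path/to/your_dataset', input_size=224)
img, target = dataset[0]   # the default ImageNetDataset returns an (image, target) pair
print(img.shape)           # expect: torch.Size([3, 224, 224])
```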