[upd] add notes for 384 training

main
keyu tian 2 years ago
parent 93cb3e4a92
commit c4d28a1c6a
1. README.md (24 lines changed)
2. downstream_imagenet/README.md (5 lines changed)
3. downstream_imagenet/arg.py (26 lines changed)
4. downstream_imagenet/data.py (14 lines changed)
5. downstream_imagenet/util.py (2 lines changed)
6. pretrain/README.md (29 lines changed)
7. pretrain/main.py (2 lines changed)
8. pretrain/models/__init__.py (2 lines changed)
9. pretrain/spark.py (38 lines changed)
10. pretrain/viz1.png (binary)

@@ -82,16 +82,20 @@ https://user-images.githubusercontent.com/6366788/213662770-5f814de0-cbe8-48d9-8
## ImageNet-1k results and pre-trained networks weights
**Note: for network definitions, we directly use `timm.models.ResNet` and [official ConvNeXt](https://github.com/facebookresearch/ConvNeXt/blob/048efcea897d999aed302f2639b6270aedf8d4c8/models/convnext.py).**
| arch. | acc@1 | #params | flops | model |
|:---:|:---:|:---:|:---:|:---:|
| ResNet50 | 80.6 | 26M | 4.1G | [drive](https://drive.google.com/file/d/1H8605HbxGvrsu4x4rIoNr-Wkd7JkxFPQ/view?usp=share_link) |
| ResNet101 | 82.2 | 45M | 7.9G | [drive](https://drive.google.com/file/d/1ZwTztjU-_rfvOVfLoce9SMw2Fx0DQfoO/view?usp=share_link) |
| ResNet152 | 82.7 | 60M | 11.6G | [drive](https://drive.google.com/file/d/1FOVuECnzQAI-OzE-hnrqW7tVpg8kTziM/view?usp=share_link) |
| ResNet200 | 83.1 | 65M | 15.1G | [drive](https://drive.google.com/file/d/1_Q4e30qqhjchrdyW3fT6P98Ga-WnQ57s/view?usp=share_link) |
| ConvNeXt-S | 84.1 | 50M | 8.7G | [drive](https://drive.google.com/file/d/1Ah6lgDY5YDNXoXHQHklKKMbEd08RYivN/view?usp=share_link) |
| ConvNeXt-B | 84.8 | 89M | 15.4G | [drive](https://drive.google.com/file/d/1ZjWbqI1qoBcqeQijI5xX9E-YNkxpJcYV/view?usp=share_link) |
| ConvNeXt-L | 85.4 | 198M | 34.4G | [drive](https://drive.google.com/file/d/1qfYzGUpYBzuA88_kXkVl4KNUwfutMVfw/view?usp=share_link) |
`reso.`: the image resolution; `acc@1`: ImageNet-1k fine-tuned top-1 accuracy
| arch. | reso. | acc@1 | #params | flops | weights on google drive |
|:--------------:|:-----:|:-----:|:-------:|:------:|:------------------------------------------------------------------------------------------------------------------------------------------|
| ResNet50 | 224 | 80.6 | 26M | 4.1G | [resnet50_1kpretrained_timm_style.pth](https://drive.google.com/file/d/1H8605HbxGvrsu4x4rIoNr-Wkd7JkxFPQ/view?usp=share_link) |
| ResNet101 | 224 | 82.2 | 45M | 7.9G | [resnet101_1kpretrained_timm_style.pth](https://drive.google.com/file/d/1ZwTztjU-_rfvOVfLoce9SMw2Fx0DQfoO/view?usp=share_link) |
| ResNet152 | 224 | 82.7 | 60M | 11.6G | [resnet152_1kpretrained_timm_style.pth](https://drive.google.com/file/d/1FOVuECnzQAI-OzE-hnrqW7tVpg8kTziM/view?usp=share_link) |
| ResNet200 | 224 | 83.1 | 65M | 15.1G | [resnet200_1kpretrained_timm_style.pth](https://drive.google.com/file/d/1_Q4e30qqhjchrdyW3fT6P98Ga-WnQ57s/view?usp=share_link) |
| ConvNeXt-S | 224 | 84.1 | 50M | 8.7G | [convnextS_1kpretrained_official_style.pth](https://drive.google.com/file/d/1Ah6lgDY5YDNXoXHQHklKKMbEd08RYivN/view?usp=share_link) |
| ConvNeXt-B | 224 | 84.8 | 89M | 15.4G | [convnextB_1kpretrained_official_style.pth](https://drive.google.com/file/d/1ZjWbqI1qoBcqeQijI5xX9E-YNkxpJcYV/view?usp=share_link) |
| ConvNeXt-L | 224 | 85.4 | 198M | 34.4G | [convnextL_1kpretrained_official_style.pth](https://drive.google.com/file/d/1qfYzGUpYBzuA88_kXkVl4KNUwfutMVfw/view?usp=share_link) |
| ConvNeXt-L | 384 | 86.0 | 198M | 101.0G | [convnextL_384_1kpretrained_official_style.pth](https://drive.google.com/file/d/1YgWNXJjI89l35P4ksAmBNWZ2JZCpj9n4/view?usp=share_link) |
| L-with-decoder | 384 | 86.0 | 198M | 101.0G | [cnxL384_withdecoder_1kpretrained_spark_style.pth](https://drive.google.com/file/d/1ZI9Jgtb3fKWE_vDFEly29w-1FWZSNwa0/view?usp=share_link) |

@@ -13,9 +13,8 @@ See [INSTALL.md](https://github.com/keyu-tian/SparK/blob/main/INSTALL.md) to pre
## Fine-tuning on ImageNet-1k from pre-trained weights
Run [downstream_imagenet/main.sh](https://github.com/keyu-tian/SparK/blob/main/downstream_imagenet/main.sh).
It is **required** to specify ImageNet data folder, model name, and checkpoint file path to run fine-tuning.
All the other arguments have their default values, listed in [downstream_imagenet/arg.py#L13](https://github.com/keyu-tian/SparK/blob/main/downstream_imagenet/arg.py#L13).
It is **required** to specify your experiment name `<experiment_name>`, the ImageNet data folder `--data_path`, the model name `--model`, and the checkpoint file path `--resume_from` to run fine-tuning.
All the other configurations have their default values, listed in [downstream_imagenet/arg.py#L13](https://github.com/keyu-tian/SparK/blob/main/downstream_imagenet/arg.py#L13).
You can override any defaults by passing key-word arguments (like `--bs=2048`) to `main.sh`.
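As a rough illustration of how such a keyword override interacts with the per-model defaults (a standalone sketch with hypothetical names, not the repo's actual `main.sh`/`arg.py` parsing), an explicitly passed `--bs=2048` wins while unset fields fall back to their defaults:

```python
# Hypothetical sketch: an explicit `--bs=2048` overrides the per-model default,
# while any option left at its "unset" sentinel falls back to the default value.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, required=True)
parser.add_argument('--bs', type=int, default=0)           # 0 means "not specified"
parser.add_argument('--base_lr', type=float, default=0.0)  # 0.0 means "not specified"

DEFAULTS = {'convnext_large': (4096, 0.0001), 'resnet50': (4096, 0.002)}  # (bs, base_lr), illustrative table

args = parser.parse_args(['--model=convnext_large', '--bs=2048'])
default_bs, default_lr = DEFAULTS[args.model]
args.bs = args.bs or default_bs            # stays 2048
args.base_lr = args.base_lr or default_lr  # falls back to 0.0001
print(args.bs, args.base_lr)               # -> 2048 0.0001
```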

@@ -12,14 +12,15 @@ from tap import Tap
HP_DEFAULT_NAMES = ['bs', 'ep', 'wp_ep', 'opt', 'base_lr', 'lr_scale', 'wd', 'mixup', 'rep_aug', 'drop_path', 'ema']
HP_DEFAULT_VALUES = {
    'convnext_small':     (4096, 400, 20, 'adam', 0.0002,  0.7, 0.01, 0.8, 3, 0.3, 0.9999),
    'convnext_base':      (4096, 400, 20, 'adam', 0.0001,  0.7, 0.01, 0.8, 3, 0.4, 0.9999),
    'convnext_large':     (4096, 200, 10, 'adam', 0.0001,  0.7, 0.02, 0.8, 3, 0.5, 0.9999),
    'convnext_large_384': (1024, 200, 20, 'adam', 0.00006, 0.7, 0.01, 0.8, 3, 0.5, 0.99995),
    'resnet50':           (4096, 300, 5, 'lamb', 0.002, 0.7, 0.02, 0.1, 0, 0.05, 0.9999),
    'resnet101':          (4096, 300, 5, 'lamb', 0.001, 0.8, 0.02, 0.1, 0, 0.2, 0.9999),
    'resnet152':          (4096, 300, 5, 'lamb', 0.001, 0.8, 0.02, 0.1, 0, 0.2, 0.9999),
    'resnet200':          (4096, 300, 5, 'lamb', 0.001, 0.8, 0.02, 0.1, 0, 0.2, 0.9999),
}
@@ -116,11 +117,14 @@ def get_args(world_size, global_rank, local_rank, device) -> FineTuneArgs:
try: os.makedirs(args.tb_lg_dir, exist_ok=True)
except: pass
# fill in args.bs, args.ep, etc. with their per-model default values if they were not explicitly specified (i.e., if their current value is falsy like 0, 0.0, or '')
if args.model == 'convnext_large' and args.img_size == 384:
    default_values = HP_DEFAULT_VALUES['convnext_large_384']
else:
    default_values = HP_DEFAULT_VALUES[args.model]
for k, v in zip(HP_DEFAULT_NAMES, default_values):
    if not getattr(args, k):
        setattr(args, k, v)
args.ema = args.ema or 0.9999
# update other runtime args
args.world_size, args.global_rank, args.local_rank, args.device = world_size, global_rank, local_rank, device
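To make the new 384 branch concrete, here is a self-contained sketch (a hypothetical `args` object; the default rows are copied from the table above) showing how fine-tuning ConvNeXt-Large at 384 picks up the `convnext_large_384` defaults while an explicitly set field is left untouched:

```python
from types import SimpleNamespace

HP_DEFAULT_NAMES = ['bs', 'ep', 'wp_ep', 'opt', 'base_lr', 'lr_scale', 'wd', 'mixup', 'rep_aug', 'drop_path', 'ema']
HP_DEFAULT_VALUES = {
    'convnext_large':     (4096, 200, 10, 'adam', 0.0001,  0.7, 0.02, 0.8, 3, 0.5, 0.9999),
    'convnext_large_384': (1024, 200, 20, 'adam', 0.00006, 0.7, 0.01, 0.8, 3, 0.5, 0.99995),
}

# hypothetical args: fine-tune ConvNeXt-L at 384, only the batch size is set explicitly
args = SimpleNamespace(model='convnext_large', img_size=384, bs=512, ep=0, wp_ep=0, opt='',
                       base_lr=0.0, lr_scale=0.0, wd=0.0, mixup=0.0, rep_aug=0, drop_path=0.0, ema=0.0)

key = 'convnext_large_384' if (args.model == 'convnext_large' and args.img_size == 384) else args.model
for k, v in zip(HP_DEFAULT_NAMES, HP_DEFAULT_VALUES[key]):
    if not getattr(args, k):              # a falsy value means "not explicitly specified"
        setattr(args, k, v)

print(args.bs, args.base_lr, args.ema)    # -> 512 6e-05 0.99995
```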

@@ -39,17 +39,17 @@ def create_classification_dataset(data_path, img_size, rep_aug, workers, batch_s
auto_augment='v0', interpolation='bicubic', re_prob=0.25, re_mode='pixel', re_count=1,
mean=mean, std=std,
)
if img_size < 384:
    for i, t in enumerate(trans_train.transforms):
        if isinstance(t, (TorchAutoAugment, TimmAutoAugment)):
            trans_train.transforms[i] = TrivialAugmentWide(interpolation=interpolation)
            break
    trans_val = transforms_imagenet_eval(img_size=img_size, interpolation='bicubic', crop_pct=0.95, mean=mean, std=std)
else:
    trans_val = transforms.Compose([
        transforms.Resize((img_size, img_size), interpolation=interpolation),
        transforms.ToTensor(), transforms.Normalize(mean=mean, std=std),
    ])
print_transform(trans_train, '[train]')
print_transform(trans_val, '[val]')
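For reference, the resolution switch above can be approximated with plain torchvision (a sketch: below 384 it mimics timm's eval pipeline with `crop_pct=0.95` via resize-then-center-crop, which may differ slightly from `transforms_imagenet_eval`; at 384 and above it resizes straight to a square, as in the new code):

```python
# Sketch: resolution-dependent validation transform, approximating the logic above.
from torchvision import transforms

IMAGENET_MEAN, IMAGENET_STD = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

def build_val_transform(img_size: int, crop_pct: float = 0.95):
    norm = [transforms.ToTensor(), transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)]
    if img_size < 384:
        resize_size = int(round(img_size / crop_pct))   # e.g. 224 -> resize the shorter side to 236
        return transforms.Compose([
            transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(img_size),
            *norm,
        ])
    else:
        return transforms.Compose([
            transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.BICUBIC),
            *norm,
        ])

print(build_val_transform(224))
print(build_val_transform(384))
```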

@@ -100,7 +100,7 @@ def load_checkpoint(resume_from, model_without_ddp, ema_module, optimizer):
# return 0, '[no performance_desc]'
print(f'[try to resume from file `{resume_from}`]')
checkpoint = torch.load(resume_from, map_location='cpu')
assert checkpoint.get('is_pretrain', False) == False, 'please do not use `PT-xxxx-.pth`; it is only for pretraining'
assert checkpoint.get('is_pretrain', False) == False, 'Please do not use `*_still_pretraining.pth`, which is ONLY for resuming the pretraining. Use `*_1kpretrained.pth` or `*_1kfinetuned*.pth` instead.'
ep_start, performance_desc = checkpoint.get('epoch', -1) + 1, checkpoint.get('performance_desc', '[no performance_desc]')
missing, unexpected = model_without_ddp.load_state_dict(checkpoint.get('module', checkpoint), strict=False)
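Relatedly, loading one of the released `*_1kpretrained*.pth` files into a `timm` backbone for fine-tuning should look roughly like this (a hedged sketch: beyond the `'module'` key used above, the exact checkpoint layout is an assumption):

```python
# Sketch: load a released 1k-pretrained checkpoint into a timm backbone.
# Assumes the file stores plain backbone weights, possibly under a 'module' key,
# as the loading code above suggests; adapt if the layout differs.
import timm, torch

model = timm.create_model('resnet50', num_classes=1000)
ckpt = torch.load('resnet50_1kpretrained_timm_style.pth', map_location='cpu')
state = ckpt.get('module', ckpt) if isinstance(ckpt, dict) else ckpt
missing, unexpected = model.load_state_dict(state, strict=False)
print('missing keys:', missing)        # e.g. the classifier head, which is trained during fine-tuning
print('unexpected keys:', unexpected)
```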

@@ -5,14 +5,14 @@ See [INSTALL.md](https://github.com/keyu-tian/SparK/blob/main/INSTALL.md) to pre
**Note: for network definitions, we directly use `timm.models.ResNet` and [official ConvNeXt](https://github.com/facebookresearch/ConvNeXt/blob/048efcea897d999aed302f2639b6270aedf8d4c8/models/convnext.py).**
## Pre-training on ImageNet-1k from scratch
## Pre-training Any Model on ImageNet-1k (224x224)
For pre-training, run [main.sh](https://github.com/keyu-tian/SparK/blob/main/pretrain/main.sh) with bash.
It is **required** to specify the ImageNet data folder (`--data_path`), the model name (`--model`), and your experiment name (the first argument of `main.sh`) when running the script.
Note that it is **required** to specify the ImageNet data folder (`--data_path`) and model name (`--model`) to run pre-training.
For **all** other configurations/hyperparameters, their names and **default values** can be found in [utils/arg_util.py line24-47](https://github.com/keyu-tian/SparK/blob/main/pretrain/utils/arg_util.py#L24).
If you do not specify them like `--ep=800`, those default configurations would be used.
We use the **same** pre-training configurations (lr, batch size, etc.) for all models (ResNets and ConvNeXts).
Their names and **default values** can be found in [utils/arg_util.py lines 24-47](https://github.com/keyu-tian/SparK/blob/main/pretrain/utils/arg_util.py#L24).
These default configurations (like batch size 4096) will be used unless you override them, e.g., by passing `--bs=512`.
Here is an example command for pre-training a ResNet50 on a single machine with 8 GPUs:
```shell script
@@ -20,16 +20,29 @@ $ cd /path/to/SparK/pretrain
$ bash ./main.sh <experiment_name> \
--num_nodes=1 --ngpu_per_node=8 \
--data_path=/path/to/imagenet \
--model=resnet50
--model=resnet50 --bs=512
```
For multiple machines, change the `num_nodes` to your count and plus these args:
For multiple machines, change `--num_nodes` to the number of machines you use, and add these arguments:
```shell script
--node_rank=<rank_starts_from_0> --master_address=<some_address> --master_port=<some_port>
```
Note that the first argument `<experiment_name>` is the name of your experiment, which would be used to create an output directory named `output_<experiment_name>`.
Note that `<experiment_name>` is the name of your experiment; it will be used to create an output directory named `output_<experiment_name>`.
## Pre-training ConvNeXt-Large on ImageNet-1k (384x384)
For pre-training with resolution 384, we use a larger mask ratio (0.75), a smaller batch size (2048), and a larger learning rate (4e-4):
```shell script
$ cd /path/to/SparK/pretrain
$ bash ./main.sh <experiment_name> \
--num_nodes=8 --ngpu_per_node=8 --node_rank=... --master_address=... --master_port=... \
--data_path=/path/to/imagenet \
--model=convnext_large --input_size=384 --mask=0.75 \
--bs=2048 --base_lr=4e-4
```
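To put `--input_size=384 --mask=0.75` in perspective (assuming the masking grid is the encoder's final feature map, i.e. an overall downsample ratio of 32, as for ResNets and ConvNeXts), the patch counts work out as follows:

```python
# Back-of-the-envelope: masked/visible patch counts for different settings.
# Assumes the masking grid has input_size // 32 patches per side.
def mask_stats(input_size: int, mask_ratio: float, downsample: int = 32):
    f = input_size // downsample            # feature-map side length
    total = f * f                           # number of maskable patches
    keep = round(total * (1 - mask_ratio))  # visible (unmasked) patches
    return f, total, keep

print(mask_stats(224, 0.60))   # (7, 49, 20), a possible 224-resolution setting
print(mask_stats(384, 0.75))   # (12, 144, 36), the 384 recipe above
```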
## Logging

@@ -136,7 +136,7 @@ def pre_train_one_ep(ep, args: arg_util.Args, tb_lg: misc.TensorboardLogger, itr
# forward and backward
inp = inp.to(args.device, non_blocking=True)
SparK.forward
_, _, loss = model(inp)
loss = model(inp, active_b1ff=None, vis=False)
optimizer.zero_grad()
loss.backward()
loss = loss.item()

@@ -63,5 +63,5 @@ def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0
kwargs, params, flops, downsample_raito, fea_dim = pre_train_d[name]
if drop_path_rate != 0:
kwargs['drop_path_rate'] = drop_path_rate
print(f'[sparse_cnn] model kwargs={kwargs}')
print(f'[build_sparse_encoder] model kwargs={kwargs}')
return SparseEncoder(create_model(name, **kwargs), input_size=input_size, downsample_raito=downsample_raito, encoder_fea_dim=fea_dim, sbn=sbn, verbose=verbose)

@@ -52,11 +52,11 @@ class SparK(nn.Module):
if i == 0 and e_width == d_width:
densify_proj = nn.Identity()  # note: ConvNeXt-S would use this, because its deepest encoder width (768) equals the decoder's width (768)
print(f'[mid, py={self.hierarchy}][densify {i} proj]: use nn.Identity()')
print(f'[SparK.__init__, densify {i+1}/{self.hierarchy}]: use nn.Identity() as densify_proj')
else:
kernel_size = 1 if i <= 0 else 3
densify_proj = nn.Conv2d(e_width, d_width, kernel_size=kernel_size, stride=1, padding=kernel_size // 2, bias=True)
print(f'[mid, py={self.hierarchy}][densify {i} proj]: k={kernel_size}, #para = {sum(x.numel() for x in densify_proj.parameters()) / 1e6:.2f}')
print(f'[SparK.__init__, densify {i+1}/{self.hierarchy}]: densify_proj(ksz={kernel_size}, #para={sum(x.numel() for x in densify_proj.parameters()) / 1e6:.2f}M)')
self.densify_projs.append(densify_proj)
p = nn.Parameter(torch.zeros(1, e_width, 1, 1))
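The width-matching branch a few lines above can be read with concrete numbers (a sketch of just the selection rule, using the 768-wide ConvNeXt-S case from the comment; not the actual module construction):

```python
from torch import nn

def choose_densify_proj(i: int, e_width: int, d_width: int) -> nn.Module:
    # Mirrors the rule above: identity only at the deepest level (i == 0) when the
    # encoder width already matches the decoder width; otherwise a 1x1 conv at
    # level 0 and 3x3 convs at the shallower levels.
    if i == 0 and e_width == d_width:
        return nn.Identity()
    ksz = 1 if i <= 0 else 3
    return nn.Conv2d(e_width, d_width, kernel_size=ksz, stride=1, padding=ksz // 2, bias=True)

# ConvNeXt-S style: deepest encoder width 768, decoder width 768, both halved per level
e_width = d_width = 768
for i in range(3):
    print(i, e_width, d_width, type(choose_densify_proj(i, e_width, d_width)).__name__)
    e_width //= 2
    d_width //= 2
# -> 0 768 768 Identity / 1 384 384 Conv2d / 2 192 192 Conv2d
```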
@@ -65,7 +65,7 @@ class SparK(nn.Module):
e_width //= 2
d_width //= 2
print(f'[mid, py={self.hierarchy}][mask_tokens]: {tuple(p.numel() for p in self.mask_tokens)}')
print(f'[SparK.__init__] dims of mask_tokens={tuple(p.numel() for p in self.mask_tokens)}')
m = torch.tensor(IMAGENET_DEFAULT_MEAN).view(1, 3, 1, 1)
s = torch.tensor(IMAGENET_DEFAULT_STD).view(1, 3, 1, 1)
@@ -80,9 +80,9 @@ class SparK(nn.Module):
idx = idx[:, :self.len_keep].to(device) # (B, len_keep)
return torch.zeros(B, f * f, dtype=torch.bool, device=device).scatter_(dim=1, index=idx, value=True).view(B, 1, f, f)
def forward(self, inp_bchw: torch.Tensor, active_b1ff=None):
def forward(self, inp_bchw: torch.Tensor, active_b1ff=None, vis=False):
# step1. Mask
if active_b1ff is None:
if active_b1ff is None: # rand mask
active_b1ff: torch.BoolTensor = self.mask(inp_bchw.shape[0], inp_bchw.device) # (B, 1, f, f)
encoder._cur_active = active_b1ff # (B, 1, f, f)
active_b1hw = active_b1ff.repeat_interleave(self.downsample_raito, 2).repeat_interleave(self.downsample_raito, 3) # (B, 1, H, W)
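For intuition, the random mask that kicks in when `active_b1ff is None` can be reproduced in a few standalone lines (following `mask()` and the `repeat_interleave` upsampling above; `len_keep` would be roughly `round(f*f * (1 - mask_ratio))`):

```python
# Sketch: build a random (B, 1, f, f) boolean keep-mask with `len_keep` visible patches,
# then upsample it to pixel resolution with repeat_interleave, as in forward() above.
import torch

B, f, downsample_ratio = 2, 7, 32             # e.g. a 224x224 input and a 7x7 feature map
mask_ratio = 0.6                              # hypothetical ratio
len_keep = round(f * f * (1 - mask_ratio))    # 20 visible patches out of 49

idx = torch.rand(B, f * f).argsort(dim=1)[:, :len_keep]   # random patches to keep
active_b1ff = torch.zeros(B, f * f, dtype=torch.bool).scatter_(1, idx, True).view(B, 1, f, f)
active_b1hw = active_b1ff.repeat_interleave(downsample_ratio, 2).repeat_interleave(downsample_ratio, 3)

print(active_b1ff.sum(dim=(1, 2, 3)))   # tensor([20, 20]), visible patches per image
print(active_b1hw.shape)                # torch.Size([2, 1, 224, 224])
```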
@@ -106,20 +106,22 @@ class SparK(nn.Module):
# step4. Decode and reconstruct
rec_bchw = self.dense_decoder(to_dec)
inp, rec = self.patchify(inp_bchw), self.patchify(rec_bchw)   # inp and rec: (B, L = f*f, N = C*downsample_raito**2)
mean = inp.mean(dim=-1, keepdim=True)
var = (inp.var(dim=-1, keepdim=True) + 1e-6) ** .5
inp = (inp - mean) / var
l2_loss = ((rec - inp) ** 2).mean(dim=2, keepdim=False)   # (B, L, C) ==mean==> (B, L)
non_active = active_b1ff.logical_not().int().view(active_b1ff.shape[0], -1)   # (B, 1, f, f) => (B, L)
recon_loss = l2_loss.mul_(non_active).sum() / (non_active.sum() + 1e-8)   # loss only on masked (non-active) patches

if vis:
    masked_bchw = inp_bchw * active_b1hw
    rec_bchw = self.unpatchify(rec * var + mean)
    rec_or_inp = torch.where(active_b1hw, inp_bchw, rec_bchw)
    return [self.denorm_for_vis(i) for i in (inp_bchw, masked_bchw, rec_or_inp)]
else:
    return recon_loss
def patchify(self, bchw):
p = self.downsample_raito
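For reference, the objective above (a per-patch normalized L2 loss averaged only over the masked patches) can be reproduced standalone; in this sketch random tensors stand in for the patchified target and prediction:

```python
# Sketch: per-patch normalized L2, averaged over masked (non-active) patches only,
# mirroring the loss in forward() above. Shapes: (B, L, N) with L = f*f patches
# and N = pixels per patch; `active` is the (B, 1, f, f) keep-mask.
import torch

B, f, N = 2, 7, 32 * 32 * 3
L = f * f
inp, rec = torch.randn(B, L, N), torch.randn(B, L, N)   # "patchified" target and prediction
active = torch.rand(B, 1, f, f) > 0.6                    # hypothetical keep-mask

mean = inp.mean(dim=-1, keepdim=True)
var = (inp.var(dim=-1, keepdim=True) + 1e-6) ** 0.5
inp_norm = (inp - mean) / var                            # per-patch normalization

l2 = ((rec - inp_norm) ** 2).mean(dim=2)                 # (B, L, N) -> (B, L)
non_active = active.logical_not().int().view(B, -1)      # 1 where the patch was masked
recon_loss = (l2 * non_active).sum() / (non_active.sum() + 1e-8)
print(recon_loss)
```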
@@ -150,13 +152,13 @@ class SparK(nn.Module):
return {
# self
'mask_ratio': self.mask_ratio,
'en_de_norm': self.densify_norm_str,
'densify_norm_str': self.densify_norm_str,
'sbn': self.sbn, 'hierarchy': self.hierarchy,
# enc
'input_size': self.sparse_encoder.input_size,
'sparse_encoder.input_size': self.sparse_encoder.input_size,
# dec
'dec_fea_dim': self.dense_decoder.width,
'dense_decoder.width': self.dense_decoder.width,
}
def state_dict(self, destination=None, prefix='', keep_vars=False, with_config=False):

Binary file not shown.
