[upd] add small model for more lightweight visualization

Branch: main
keyu-tian committed 2 years ago
parent bea2b8281b
commit 48401f6ad9
Changed files:
1. README.md (3)
2. pretrain/viz_reconstruction.ipynb (33)

README.md
@@ -101,7 +101,7 @@ We also provide [pretrain/viz_spconv.ipynb](pretrain/viz_spconv.ipynb) that show
</details>
-## SparK Pretrained weights
+## SparK pretrained weights (self-supervised)
**Note: for network definitions, we directly use `timm.models.ResNet` and [official ConvNeXt](https://github.com/facebookresearch/ConvNeXt/blob/048efcea897d999aed302f2639b6270aedf8d4c8/models/convnext.py).**
@@ -118,6 +118,7 @@ We also provide [pretrain/viz_spconv.ipynb](pretrain/viz_spconv.ipynb) that show
| ConvNeXt-B | 224 | 84.8 | 89M | 15.4G | [convnextB_1kpretrained_official_style.pth](https://drive.google.com/file/d/1ZjWbqI1qoBcqeQijI5xX9E-YNkxpJcYV/view?usp=share_link) |
| ConvNeXt-L | 224 | 85.4 | 198M | 34.4G | [convnextL_1kpretrained_official_style.pth](https://drive.google.com/file/d/1qfYzGUpYBzuA88_kXkVl4KNUwfutMVfw/view?usp=share_link) |
| ConvNeXt-L | 384 | 86.0 | 198M | 101.0G | [convnextL_384_1kpretrained_official_style.pth](https://drive.google.com/file/d/1YgWNXJjI89l35P4ksAmBNWZ2JZCpj9n4/view?usp=share_link) |
+| S-with-decoder | 224 | 84.1 | 50M | 8.7G | [cnxS224_withdecoder_1kpretrained_spark_style.pth](https://drive.google.com/file/d/1bKvrE4sNq1PfzhWlQJXEPrl2kHqHRZM-/view?usp=share_link) |
| L-with-decoder | 384 | 86.0 | 198M | 101.0G | [cnxL384_withdecoder_1kpretrained_spark_style.pth](https://drive.google.com/file/d/1ZI9Jgtb3fKWE_vDFEly29w-1FWZSNwa0/view?usp=share_link) |

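Since the note above says the network definitions come straight from `timm.models.ResNet` and the official ConvNeXt, a minimal sketch of loading one of the `*_official_style` checkpoints might look like the following. It assumes the file holds a plain backbone `state_dict` (not verified here); `strict=False` guards against classifier-head key mismatches.

```python
# A hedged loading sketch, not the repo's own code: it assumes the
# "*_official_style" files store a plain backbone state_dict.
import timm
import torch

model = timm.create_model('convnext_large', pretrained=False)
state = torch.load('convnextL_1kpretrained_official_style.pth', map_location='cpu')
missing, unexpected = model.load_state_dict(state, strict=False)
print('missing keys:', missing)
print('unexpected keys:', unexpected)
```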
pretrain/viz_reconstruction.ipynb
@@ -6,7 +6,7 @@
"metadata": {},
"source": [
"# SparK: A Visualization Demo\n",
"A demo using our pre-trained SparK model (ConvNeXt-L with input size 384) to reconstruct masked images.\n",
"A demo using our pre-trained SparK model (ConvNeXt-L with input size 384, or ConvNeXt-S with 224) to reconstruct your masked images.\n",
"The mask is whether specified by the user or randomly generated."
]
},
@@ -16,7 +16,7 @@
"metadata": {},
"source": [
"## 1. Preparation\n",
"Install dependencies and specify the device."
"Install dependencies, specify the device, and specify the pre-trained model."
]
},
{
@@ -41,6 +41,8 @@
"\n",
"# specify the device to use\n",
"USING_GPU_IF_AVAILABLE = True\n",
"# specify the CNN to useTrue for ConvNeXt-L-384, False for ConvNeXt-S-224\n",
"USING_LARGE384_MODEL = True #\n",
"import torch\n",
"_ = torch.empty(1)\n",
"if torch.cuda.is_available() and USING_GPU_IF_AVAILABLE:\n",
@@ -71,7 +73,7 @@
"def load_image(img_file: str):\n",
" img = Image.open(img_file).convert('RGB')\n",
" transform = T.Compose([\n",
" T.Resize((384, 384)),\n",
" T.Resize((384, 384) if USING_LARGE384_MODEL else (224, 224)),\n",
" T.ToTensor(),\n",
" T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),\n",
" ])\n",
@@ -101,13 +103,19 @@
"from spark import SparK\n",
"def build_spark():\n",
" # download and load the checkpoint\n",
" if USING_LARGE384_MODEL:\n",
" model_name, input_size = 'convnext_large', 384\n",
" ckpt_file = 'cnxL384_withdecoder_1kpretrained_spark_style.pth'\n",
" ckpt_link = 'https://drive.google.com/file/d/1ZI9Jgtb3fKWE_vDFEly29w-1FWZSNwa0/view?usp=share_link'\n",
" else:\n",
" model_name, input_size = 'convnext_small', 224\n",
" ckpt_file = 'cnxS224_withdecoder_1kpretrained_spark_style.pth'\n",
" ckpt_link = 'https://drive.google.com/file/d/1bKvrE4sNq1PfzhWlQJXEPrl2kHqHRZM-/view?usp=share_link'\n",
" assert os.path.exists(ckpt_file), f'please download checkpoint {ckpt_file} from {ckpt_link}'\n",
" pretrained_state = torch.load(ckpt_file, map_location='cpu')\n",
" \n",
" # build a SparK model\n",
" enc: SparseEncoder = build_sparse_encoder('convnext_large', input_size=384)\n",
" enc: SparseEncoder = build_sparse_encoder(model_name, input_size=input_size)\n",
" spark = SparK(\n",
" sparse_encoder=enc, dense_decoder=LightDecoder(enc.downsample_raito, sbn=False),\n",
" mask_ratio=0.6, densify_norm='ln', sbn=False, hierarchy=4,\n",
@@ -202,7 +210,8 @@
],
"source": [
"# specify the mask\n",
"show(spark, 'viz_imgs/recon.png', active_b1ff=torch.tensor([\n",
"if USING_LARGE384_MODEL:\n",
" active_b1ff = torch.tensor([\n",
" [0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0],\n",
@@ -215,7 +224,19 @@
" [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],\n",
" [1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],\n",
"], device=DEVICE).bool().unsqueeze(0).unsqueeze(0))"
" ], device=DEVICE).bool().reshape(1, 1, 12, 12)\n",
"else:\n",
" active_b1ff = torch.tensor([\n",
" [0, 0, 0, 0, 0, 0, 0],\n",
" [1, 0, 0, 0, 0, 0, 1],\n",
" [0, 0, 0, 0, 0, 1, 0],\n",
" [0, 0, 0, 0, 0, 0, 0],\n",
" [1, 0, 1, 0, 1, 0, 0],\n",
" [0, 0, 0, 0, 0, 1, 0],\n",
" [0, 0, 0, 1, 0, 0, 0],\n",
" ], device=DEVICE).bool().reshape(1, 1, 7, 7)\n",
"\n",
"show(spark, 'viz_imgs/recon.png', active_b1ff=active_b1ff)"
]
},
{
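The demo's intro says the mask is either user-specified or randomly generated; a hedged sketch of the random variant, reusing `spark`, `show`, `DEVICE`, and `USING_LARGE384_MODEL` from the cells above and the 0.6 mask ratio from the `SparK(...)` call:

```python
# Hedged sketch: a random mask instead of the hand-written grids above.
# rand > 0.6 keeps ~40% of patches visible, matching mask_ratio=0.6.
h = w = 12 if USING_LARGE384_MODEL else 7
active_b1ff = torch.rand(1, 1, h, w, device=DEVICE) > 0.6  # True = visible patch
show(spark, 'viz_imgs/recon.png', active_b1ff=active_b1ff)
```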
