@ -6,7 +6,7 @@
"metadata": {},
"source": [
"# SparK: A Visualization Demo\n",
"A demo using our pre-trained SparK model (ConvNeXt-L with input size 384) to reconstruct masked images.\n",
"A demo using our pre-trained SparK model (ConvNeXt-L with input size 384, or ConvNeXt-S with input size 224) to reconstruct your masked images.\n",
"The mask is either specified by the user or randomly generated."
]
},
@ -16,7 +16,7 @@
"metadata": {},
"source": [
"## 1. Preparation\n",
"Install dependencies and specify the device ."
"Install dependencies, specify the device, and specify the pre-trained model."
]
},
{
@ -41,6 +41,8 @@
"\n",
"# specify the device to use\n",
"USING_GPU_IF_AVAILABLE = True\n",
"# specify the CNN to use: True for ConvNeXt-L-384, False for ConvNeXt-S-224\n",
"USING_LARGE384_MODEL = True #\n",
"import torch\n",
"_ = torch.empty(1)\n",
"if torch.cuda.is_available() and USING_GPU_IF_AVAILABLE:\n",
@ -71,7 +73,7 @@
"def load_image(img_file: str):\n",
" img = Image.open(img_file).convert('RGB')\n",
" transform = T.Compose([\n",
" T.Resize((384, 384)),\n",
" T.Resize((384, 384) if USING_LARGE384_MODEL else (224, 224) ),\n",
" T.ToTensor(),\n",
" T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),\n",
" ])\n",
@ -101,13 +103,19 @@
"from spark import SparK\n",
"def build_spark():\n",
" # download and load the checkpoint\n",
" ckpt_file = 'cnxL384_withdecoder_1kpretrained_spark_style.pth'\n",
" ckpt_link = 'https://drive.google.com/file/d/1ZI9Jgtb3fKWE_vDFEly29w-1FWZSNwa0/view?usp=share_link'\n",
" if USING_LARGE384_MODEL:\n",
" model_name, input_size = 'convnext_large', 384\n",
" ckpt_file = 'cnxL384_withdecoder_1kpretrained_spark_style.pth'\n",
" ckpt_link = 'https://drive.google.com/file/d/1ZI9Jgtb3fKWE_vDFEly29w-1FWZSNwa0/view?usp=share_link'\n",
" else:\n",
" model_name, input_size = 'convnext_small', 224\n",
" ckpt_file = 'cnxS224_withdecoder_1kpretrained_spark_style.pth'\n",
" ckpt_link = 'https://drive.google.com/file/d/1bKvrE4sNq1PfzhWlQJXEPrl2kHqHRZM-/view?usp=share_link'\n",
" assert os.path.exists(ckpt_file), f'please download checkpoint {ckpt_file} from {ckpt_link}'\n",
" pretrained_state = torch.load(ckpt_file, map_location='cpu')\n",
" \n",
" # build a SparK model\n",
" enc: SparseEncoder = build_sparse_encoder('convnext_large', input_size=384)\n",
" enc: SparseEncoder = build_sparse_encoder(model_name, input_size=input_size )\n",
" spark = SparK(\n",
" sparse_encoder=enc, dense_decoder=LightDecoder(enc.downsample_raito, sbn=False),\n",
" mask_ratio=0.6, densify_norm='ln', sbn=False, hierarchy=4,\n",
@ -202,20 +210,33 @@
],
"source": [
"# specify the mask\n",
"show(spark, 'viz_imgs/recon.png', active_b1ff=torch.tensor([\n",
" [0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],\n",
" [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],\n",
" [1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],\n",
"], device=DEVICE).bool().unsqueeze(0).unsqueeze(0))"
"if USING_LARGE384_MODEL:\n",
" active_b1ff = torch.tensor([\n",
" [0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],\n",
" [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],\n",
" [1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],\n",
" ], device=DEVICE).bool().reshape(1, 1, 12, 12)\n",
"else:\n",
" active_b1ff = torch.tensor([\n",
" [0, 0, 0, 0, 0, 0, 0],\n",
" [1, 0, 0, 0, 0, 0, 1],\n",
" [0, 0, 0, 0, 0, 1, 0],\n",
" [0, 0, 0, 0, 0, 0, 0],\n",
" [1, 0, 1, 0, 1, 0, 0],\n",
" [0, 0, 0, 0, 0, 1, 0],\n",
" [0, 0, 0, 1, 0, 0, 0],\n",
" ], device=DEVICE).bool().reshape(1, 1, 7, 7)\n",
"\n",
"show(spark, 'viz_imgs/recon.png', active_b1ff=active_b1ff)"
]
},
{