{
"cells": [
{
"cell_type": "markdown",
"id": "c4b4c116",
"metadata": {},
"source": [
"# SparK: A Visualization Demo\n",
"A demo using our pre-trained SparK model (ConvNeXt-L with input size 384, or ConvNeXt-S with 224) to reconstruct masked images.\n",
"The mask is whether specified by the user or randomly generated."
]
},
{
"cell_type": "markdown",
"id": "d4ff1c8e",
"metadata": {},
"source": [
"## 1. Preparation\n",
"Install dependencies, specify the device, and specify the pre-trained model."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f181fcb4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[DEVICE=cuda:0]\n"
]
}
],
"source": [
"# install dependencies\n",
"import sys\n",
"if 'google.colab' in sys.modules:\n",
" !pip3 install -r requirements.txt\n",
"\n",
"# specify the device to use\n",
"USING_GPU_IF_AVAILABLE = True\n",
"# specify the CNN to use\n",
"USING_LARGE384_MODEL = True # True for ConvNeXt-L-384, False for ConvNeXt-S-224\n",
"import torch\n",
"_ = torch.empty(1)\n",
"if torch.cuda.is_available() and USING_GPU_IF_AVAILABLE:\n",
" _ = _.cuda()\n",
"DEVICE = _.device\n",
"print(f'[DEVICE={DEVICE}]')\n"
]
},
{
"cell_type": "markdown",
"id": "3e451959",
"metadata": {},
"source": [
"## 2. Define utils: load data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c12b129c",
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image\n",
"import torchvision.transforms as T\n",
"IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)\n",
"IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)\n",
"def load_image(img_file: str):\n",
" img = Image.open(img_file).convert('RGB')\n",
" transform = T.Compose([\n",
" T.Resize((384, 384) if USING_LARGE384_MODEL else (224, 224)),\n",
" T.ToTensor(),\n",
" T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),\n",
" ])\n",
" img = transform(img)\n",
" return img.unsqueeze(0).to(DEVICE)"
]
},
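{
"cell_type": "markdown",
"id": "a3f9d2e1",
"metadata": {},
"source": [
"Since `load_image` normalizes with the ImageNet statistics, its output tensor is not directly displayable. Below is a minimal sketch of the inverse transform, useful if you want to inspect a loaded image yourself; the helper name `denormalize` is our own, not part of SparK."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a3f9d2e2",
"metadata": {},
"outputs": [],
"source": [
"def denormalize(img_bchw: torch.Tensor) -> torch.Tensor:\n",
" # undo T.Normalize channel-wise (x * std + mean) and clamp to [0, 1] for plotting\n",
" mean = torch.tensor(IMAGENET_DEFAULT_MEAN, device=img_bchw.device).view(1, 3, 1, 1)\n",
" std = torch.tensor(IMAGENET_DEFAULT_STD, device=img_bchw.device).view(1, 3, 1, 1)\n",
" return (img_bchw * std + mean).clamp(0, 1)"
]
},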
{
"cell_type": "markdown",
"id": "88fb368b",
"metadata": {},
"source": [
"## 3. Define utils: build a SparK model"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3ddb968b",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from decoder import LightDecoder\n",
"from encoder import SparseEncoder\n",
"from models import build_sparse_encoder\n",
"from spark import SparK\n",
"def build_spark():\n",
" # download and load the checkpoint\n",
" if USING_LARGE384_MODEL:\n",
" model_name, input_size = 'convnext_large', 384\n",
" ckpt_file = 'cnxL384_withdecoder_1kpretrained_spark_style.pth'\n",
" ckpt_link = 'https://drive.google.com/file/d/1ZI9Jgtb3fKWE_vDFEly29w-1FWZSNwa0/view?usp=share_link'\n",
" else:\n",
" input_size = 224\n",
" model_name, ckpt_file, ckpt_link = {\n",
" 'ResNet50': ('resnet50', 'res50_withdecoder_1kpretrained_spark_style.pth', 'https://drive.google.com/file/d/1STt3w3e5q9eCPZa8VzcJj1zG6p3jLeSF/view?usp=share_link'),\n",
" 'ResNet101': ('resnet101', 'res101_withdecoder_1kpretrained_spark_style.pth', 'https://drive.google.com/file/d/1GjN48LKtlop2YQre6---7ViCWO-3C0yr/view?usp=share_link'),\n",
" 'ResNet152': ('resnet152', 'res152_withdecoder_1kpretrained_spark_style.pth', 'https://drive.google.com/file/d/1U3Cd94j4ZHfYR2dUjWmsEWfjP6Opx4oo/view?usp=share_link'),\n",
" 'ResNet200': ('resnet200', 'res200_withdecoder_1kpretrained_spark_style.pth', 'https://drive.google.com/file/d/13AFSqvIr0v-2hmb4DzVza45t_lhf2CnD/view?usp=share_link'),\n",
" 'ConvNeXt-S': ('convnext_small', 'cnxS224_withdecoder_1kpretrained_spark_style.pth', 'https://drive.google.com/file/d/1bKvrE4sNq1PfzhWlQJXEPrl2kHqHRZM-/view?usp=share_link'),\n",
" }['ConvNeXt-S'] # you can choose any model here\n",
"\n",
" assert os.path.exists(ckpt_file), f'please download checkpoint {ckpt_file} from {ckpt_link}'\n",
" pretrained_state = torch.load(ckpt_file, map_location='cpu')\n",
" using_bn_in_densify = 'densify_norms.0.running_mean' in pretrained_state\n",
" \n",
" # build a SparK model\n",
" enc: SparseEncoder = build_sparse_encoder(model_name, input_size=input_size)\n",
" spark = SparK(\n",
" sparse_encoder=enc, dense_decoder=LightDecoder(enc.downsample_raito, sbn=False),\n",
" mask_ratio=0.6, densify_norm='bn' if using_bn_in_densify else 'ln', sbn=False,\n",
" ).to(DEVICE)\n",
" spark.eval(), [p.requires_grad_(False) for p in spark.parameters()]\n",
" \n",
" # load the checkpoint\n",
" missing, unexpected = spark.load_state_dict(pretrained_state, strict=False)\n",
" assert len(missing) == 0, f'load_state_dict missing keys: {missing}'\n",
" assert len(unexpected) == 0, f'load_state_dict unexpected keys: {unexpected}'\n",
" del pretrained_state\n",
" return spark\n"
]
},
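{
"cell_type": "markdown",
"id": "b7c8e4f1",
"metadata": {},
"source": [
"`build_spark` asserts that the checkpoint file already exists locally. As a convenience, here is a sketch of fetching it from the Google Drive share link programmatically; it assumes the third-party `gdown` package, and `fetch_checkpoint` is our own helper, not part of SparK."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7c8e4f2",
"metadata": {},
"outputs": [],
"source": [
"def fetch_checkpoint(ckpt_file: str, ckpt_link: str):\n",
" # download a Drive share link of the form .../file/d/<id>/view to ckpt_file, if missing\n",
" if not os.path.exists(ckpt_file):\n",
"  import gdown # assumption: pip3 install gdown (not in requirements.txt)\n",
"  file_id = ckpt_link.split('/d/')[1].split('/')[0]\n",
"  gdown.download(id=file_id, output=ckpt_file, quiet=False)"
]
},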
{
"cell_type": "markdown",
"id": "986fb2b3",
"metadata": {},
"source": [
"## 4. Define utils: visualize"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ee8e3dc9",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"def show(spark: SparK, img_file='viz_imgs/recon.png', active_b1ff: torch.BoolTensor = None):\n",
" inp_bchw = load_image(img_file)\n",
" spark.forward\n",
" inp_bchw, masked_bchw, rec_or_inp = spark(inp_bchw, active_b1ff=active_b1ff, vis=True)\n",
" # plot these three images in a row\n",
" masked_title = 'rand masked' if active_b1ff is None else 'specified masked'\n",
" for col, (title, tensor) in enumerate(zip(['input', masked_title, 'reconstructed'], [inp_bchw, masked_bchw, rec_or_inp])):\n",
" plt.subplot2grid((1, 3), (0, col))\n",
" plt.imshow(tensor[0].permute(1, 2, 0).cpu().numpy())\n",
" plt.title(title)\n",
" plt.axis('off')\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "5b8b8f49",
"metadata": {},
"source": [
"## 5. Run SparK with a specified mask"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dbe10ac3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[build_sparse_encoder] model kwargs={'sparse': True, 'drop_path_rate': 0.4, 'pretrained': False, 'num_classes': 0, 'global_pool': ''}\n",
"[SparK.__init__, densify 1/4]: densify_proj(ksz=1, #para=1.18M)\n",
"[SparK.__init__, densify 2/4]: densify_proj(ksz=3, #para=2.65M)\n",
"[SparK.__init__, densify 3/4]: densify_proj(ksz=3, #para=0.66M)\n",
"[SparK.__init__, densify 4/4]: densify_proj(ksz=3, #para=0.17M)\n",
"[SparK.__init__] dims of mask_tokens=(1536, 768, 384, 192)\n"
]
}
],
"source": [
"spark = build_spark()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f405a28d",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAC8CAYAAADl2K3eAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9e7RsR3Xfi39mVa1Hd+/XeUpHbyEJIQkERjyMAYGRZZmXwTYQ2TcGhAkmtm+Mx41zR5LfNYL4EUJiO8ExhPjaTkjuGCEGm9zc2GCCAiZxCBgLA5JAAiQkJHTej713d69VVfP3R9Va3fucI9jHloVAPTWOdj9Wr169Vq2a3/rOOb9TVFVZ2MIWtrCFLWxhj1kz3+oDWNjCFrawhS1sYd9aW4CBhS1sYQtb2MIe47YAAwtb2MIWtrCFPcZtAQYWtrCFLWxhC3uM2wIMLGxhC1vYwhb2GLcFGFjYwha2sIUt7DFuCzCwsIUtbGELW9hj3BZgYGELW9jCFrawx7gtwMDCFrawhS1sYY9xe0yDgd/93d9FRLj77ru/1YeysMe4Pf/5z+f5z3/+ltcefPBBXvGKV7Br1y5EhF//9V/nv/23/4aI8N/+23972L775ptvRkQetv09Evba176WpaWlv9bv+Os41wtb2Hbtkb4vH9Ng4NFkt912GzfffPMCmCyst5/7uZ/jgx/8IH//7/993vOe9/ADP/AD3+pDWtjCHjV2//33c/PNN3Prrbc+po/h4TL3rT6Ab6X9+I//ODfeeCNVVX2rD4XbbruNt7zlLTz/+c/noosu+lYfzsIeYfvQhz50ymsf+chHeNnLXsbf/bt/t3/t8Y9/POPxmLIsH8nDW9jCHnV2//3385a3vIWLLrqIpzzlKY/ZY3i47DHNDFhrqev6244iXdh3npVleYqD379/P2tra1teM8ZQ1zXGPKZv3YX9FW1jY+NbfQiPuG1ubn6rD+FRbY/pGeXknIGLLrqIl7zkJXz84x/nGc94BnVd87jHPY5/+2//7Wk/97GPfYyf/MmfZNeuXaysrPDqV7+aI0eObNlWRLj55ptP+e6LLrqI1772tf3+XvnKVwLwvd/7vYjIIlb512AnTpzgTW96ExdddBFVVbF3716uv/56Pv3pT/fbPP/5z+eJT3wif/Znf8b3fM/3MBgMuPjii3nXu951yv6m0ylvfvObufTSS6mqivPPP5+/9/f+HtPp9JRt/92/+3c84xnPYDgcsmPHDq699totbMB8zkA3vlSVf/kv/2U/HuCh49if+MQn+IEf+AFWV1cZDoc873nP47//9/9+ynF8/OMf5+lPfzp1XXPJJZfwr/7Vv9r2+evOzV/8xV/wvOc9j+FwyKWXXsrv/d7vAfDRj36UZz7zmQwGAy6//HI+/OEPb/n8Pffcw0/91E9x+eWXMxgM2LVrF6985StPCY21bctb3vIWLrvsMuq6ZteuXTznOc/hj//4j7/h8d16663s2bOH5z//+ayvrwPwta99jde97nWcddZZVFXFVVddxW//9m+f8tn77ruPl7/85YxGI/bu3cvP/dzPnfY6fjtaF3u+7bbb+LEf+zF27NjBc57zHCCNy2uuuYbBYMDOnTu58cYbuffee0/Zxyc+8Qle9KIXsWPHDkajEVdffTX//J//8y3bfOQjH+G5z30uo9GItbU1Xvayl3H77bef9ljuuusuXvva17K2tsbq6io33XTTKc76j//4j3nOc57D2toaS0tLXH755fyDf/APgHQfPP3pTwfgpptu6u+R3/3d3wW23sfXXnstw+Gw/+x25uTOjh49ys/93M/1c8Z5553Hq1/9ag4ePPhNj6E7b3/d9+XDZY/pMMHp7K677uIVr3gFP/ETP8FrXvMafvu3f5vXvva1XHPNNVx11VVbtv2Zn/kZ1tbWuPnmm/nCF77AO9/5Tu65555+wt6uXXvttfydv/N3+Bf/4l/wD/7BP+CKK64A6P8u7OGxN77xjfze7/0eP/MzP8OVV17JoUOH+PjHP87tt9/OU5/61H67I0eO8KIXvYhXvepV/OiP/ijvfe97+dt/+29TliWve93rAIgx8oM/+IN8/OMf5w1veANXXHEFn/3sZ/m1X/s1vvjFL/IHf/AH/f7e8pa3cPPNN/M93/M9vPWtb6UsSz7xiU/wkY98hO///u8/5TivvfZa3vOe9/DjP/7jXH/99bz61a/+hr/rIx/5CC984Qu55pprePOb34wxht/5nd/hBS94AX/yJ3/CM57xDAA++9nP8v3f//3s2bOHm2++Ge89b37zmznrrLO2fQ6PHDnCS17yEm688UZe+cpX8s53vpMbb7yRf//v/z1vetObeOMb38iP/diP8fa3v51XvOIV3HvvvSwvLwPwyU9+kv/xP/4HN954I+eddx53330373znO3n+85/PbbfdxnA4BJLD+JVf+RVe//rX84xnPIPjx4/zqU99ik9/+tNcf/31pz2uT37yk9xwww087WlP4wMf+ACDwYAHH3yQ7/7u70ZE+Jmf+Rn27NnDH/7hH/ITP/ETHD9+nDe96U0AjMdjrrvuOr761a/yd/7O3+Gcc87hPe95Dx/5yEe2fV6+HeyVr3wll112Gb/8y7+MqvJLv/RL/F//1//Fq171Kl7/+tdz4MAB3vGOd3Dttdfy53/+5z0r9cd//Me85CUvYd++ffzsz/4sZ599Nrfffjv/+T//Z372Z38WgA9/+MO88IUv5HGPexw333wz4/GYd7zjHTz72c/m05/+9Cmhz1e96lVcfPHF/Mqv/Aqf/vSn+a3f+i327t3L2972NgA+//nP85KXvISrr76at771rVRVxV133dU70iuuuIK3vvWt/MIv/AJveMMbeO5znwvA93zP9/TfcejQIV74whdy44038jf/5t88o3EOsL6+znOf+1xuv/12Xve61/HUpz6VgwcP8p/+03/ivvvu+6bH8Ejelw+L6WPYfud3fkcB/cpXvqKqqhdeeKEC+rGPfazfZv/+/VpVlf4f/8f/ccrnrrnmGm2apn/9n/yTf6KAfuADH+hfA/TNb37zKd994YUX6mte85r++X/8j/9RAb3lllsett+3sK22urqqP/3TP/0Nt3ne856ngP6zf/bP+tem06k+5SlP0b179/bX+z3veY8aY/RP/uRPtnz+Xe96lwL63//7f1dV1TvvvFONMfpDP/RDGkLYsm2Mccv3Pu95z9vyPnDK8d5yyy1bxkmMUS+77DK94YYbtuxvc3NTL774Yr3++uv7117+8pdrXdd6zz339K/ddtttaq3V7UwF3bn5f/6f/6d/7Y477lBAjTH6P//n/+xf/+AHP6iA/s7v/M6WYzrZ/vRP/1QB/bf/9t/2rz35yU/WF7/4xd/wWF7zmtfoaDRSVdWPf/zjurKyoi9+8Yt1Mpn02/zET/yE7tu3Tw8ePLjlszfeeKOurq72x/Prv/7rCuh73/vefpuNjQ299NJLvyPuyTe/+c0K6I/+6I/2r919991qrdVf+qVf2rLtZz/7WXXO9a977
/Xiiy/WCy+8UI8cObJl2/nx1t0fhw4d6l/7zGc+o8YYffWrX33Ksbzuda/bsq8f+qEf0l27dvXPf+3Xfk0BPXDgwEP+rk9+8pOnjLHOurH6rne965T3tjsn/8Iv/IIC+v73v/+Ubbvf/lDH8Ejelw+XPabDBKezK6+8skd4AHv27OHyyy/ny1/+8inbvuENb6Aoiv753/7bfxvnHP/lv/yXR+RYF3Zmtra2xic+8Qnuv//+b7idc46f/Mmf7J+XZclP/uRPsn//fv7sz/4MgP/4H/8jV1xxBU94whM4ePBg/+8FL3gBALfccgsAf/AHf0CMkV/4hV84Jc7/cOSq3Hrrrdx555382I/9GIcOHeqPY2Njg+uuu46PfexjxBgJIfDBD36Ql7/85VxwwQX956+44gpuuOGGbX/f0tISN954Y//88ssvZ21tjSuuuIJnPvOZ/evd4/n7ZjAY9I/btuXQoUNceumlrK2tbQnVrK2t8fnPf54777zzmx7PLbfcwg033MB1113H+9///j4ZWFV
"text/plain": [
"<Figure size 640x480 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# specify the mask\n",
"if USING_LARGE384_MODEL:\n",
" active_b1ff = torch.tensor([\n",
" [0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],\n",
" [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],\n",
" [1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],\n",
" ], device=DEVICE).bool().reshape(1, 1, 12, 12)\n",
"else:\n",
" active_b1ff = torch.tensor([\n",
" [0, 0, 0, 0, 0, 0, 0],\n",
" [1, 0, 0, 0, 0, 0, 1],\n",
" [0, 0, 0, 0, 0, 1, 0],\n",
" [0, 0, 0, 0, 0, 0, 0],\n",
" [1, 0, 1, 0, 1, 0, 0],\n",
" [0, 0, 0, 0, 0, 1, 0],\n",
" [0, 0, 0, 1, 0, 0, 0],\n",
" ], device=DEVICE).bool().reshape(1, 1, 7, 7)\n",
"\n",
"show(spark, 'viz_imgs/recon.png', active_b1ff=active_b1ff)"
]
},
{
"cell_type": "markdown",
"id": "6e80826f",
"metadata": {},
"source": [
"## 6. Run SparK with a random mask"
]
},
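{
"cell_type": "markdown",
"id": "c5d6a7b1",
"metadata": {},
"source": [
"When `active_b1ff=None`, SparK samples the random mask internally. For reference, here is a minimal sketch of roughly what that sampling does; `random_active_b1ff` is our own illustration, not SparK's API."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c5d6a7b2",
"metadata": {},
"outputs": [],
"source": [
"def random_active_b1ff(fmap_size: int, mask_ratio: float = 0.6) -> torch.BoolTensor:\n",
" # keep (1 - mask_ratio) of the fmap_size x fmap_size patches, chosen uniformly at random\n",
" total = fmap_size * fmap_size\n",
" n_keep = round(total * (1 - mask_ratio))\n",
" active = torch.zeros(total, dtype=torch.bool, device=DEVICE)\n",
" active[torch.randperm(total, device=DEVICE)[:n_keep]] = True\n",
" return active.reshape(1, 1, fmap_size, fmap_size)\n",
"\n",
"# e.g. a random mask at the pre-training ratio of 0.6:\n",
"# show(spark, 'viz_imgs/recon.png', active_b1ff=random_active_b1ff(12 if USING_LARGE384_MODEL else 7))"
]
},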
{
"cell_type": "code",
"execution_count": null,
"id": "e5c22039",
"metadata": {},
"outputs": [],
"source": [
"# use a random mask (don't specify the mask)\n",
"show(spark, 'viz_imgs/recon.png', active_b1ff=None)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}