{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SparK: Visualization for \"mask pattern vanishing\" issue\n",
"Load binary images and visualize the convoluted images."
]
},
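{
"cell_type": "markdown",
"metadata": {},
"source": [
"Background: a masked feature map is mostly zero, but a dense convolution lets nonzero values bleed across the mask border at every layer, so after a few layers the mask pattern \"vanishes\". A sparse convolution computes only at active positions (equivalently here: it re-masks its output), which preserves the pattern.\n",
"\n",
"The next cell is a minimal 1-D sketch of this effect, added for illustration (it is not part of the original experiment): one dense conv already widens the active region, while re-masking the output restores it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# toy 1-D sketch of \"mask pattern vanishing\" (illustrative only)\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"x = torch.tensor([[[0., 0., 1., 1., 0., 0., 0.]]]) # one active segment, rest masked\n",
"w = torch.full((1, 1, 3), 1. / 3) # 3-tap averaging kernel\n",
"\n",
"y_dense = F.conv1d(x, w, padding=1)\n",
"print((y_dense > 0).int()) # the active region has widened: the pattern starts to vanish\n",
"\n",
"mask = (x > 0).float()\n",
"print(((y_dense * mask) > 0).int()) # re-masking the output keeps the original pattern"
]
},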
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"# load images\n",
"\n",
"import torch\n",
"from torchvision.io import read_image\n",
"\n",
"N = 3\n",
"raw_inp = []\n",
"for i in range(1, N+1):\n",
" chw = read_image(f'viz_imgs/spconv{i}.png')[:1] # only take the first channel\n",
" BLACK = 0\n",
" # binarize: black represents active (1), white represents masked (0)\n",
" active = torch.where(chw == BLACK, torch.ones_like(chw, dtype=torch.float) * 255., torch.zeros_like(chw, dtype=torch.float))\n",
" raw_inp.append(active)\n",
"raw_inp = torch.stack(raw_inp, dim=0)\n",
"active = raw_inp.bool() # active: 1, masked: 0"
]
},
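{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the `viz_imgs/spconv{i}.png` files are unavailable, a random block mask can be synthesized instead. This fallback is a sketch added for convenience (the image size and patch size below are assumptions, not taken from the original files):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# optional fallback: synthesize a random patch-aligned block mask instead of loading a PNG\n",
"def random_block_mask(h=224, w=224, patch=16, keep_ratio=0.25):\n",
" gh, gw = h // patch, w // patch\n",
" keep = torch.rand(gh, gw) < keep_ratio # keep a random subset of patches active\n",
" m = keep.float().repeat_interleave(patch, 0).repeat_interleave(patch, 1)\n",
" return m * 255. # (h, w) map: active -> 255., masked -> 0.\n",
"\n",
"# usage (replaces the PNG-loading loop above):\n",
"# raw_inp = torch.stack([random_block_mask().unsqueeze(0) for _ in range(N)], dim=0)\n",
"# active = raw_inp.bool()"
]
},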
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"# apply dense conv and sparse conv\n",
"\n",
"import encoder\n",
"dense_conv = torch.nn.Conv2d(1, 1, kernel_size=(3, 3), stride=1, padding=1, bias=False)\n",
"sparse_conv = encoder.SparseConv2d(1, 1, kernel_size=(3, 3), stride=1, padding=1, bias=False)\n",
"\n",
"dense_conv.weight.data.fill_(1/dense_conv.weight.numel()), dense_conv.weight.requires_grad_(False)\n",
"sparse_conv.weight.data.fill_(1/sparse_conv.weight.numel()), sparse_conv.weight.requires_grad_(False)\n",
"\n",
"# after the first dense conv\n",
"conv1 = (dense_conv(raw_inp) > 0) * 255. # binarize\n",
"# after the second dense conv\n",
"conv2 = (dense_conv(conv1) > 0) * 255. # binarize\n",
"\n",
"# after the sparse conv\n",
"encoder._cur_active = active\n",
"spconv = (sparse_conv(raw_inp) > 0) * 255. # binarize"
]
},
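{
"cell_type": "markdown",
"metadata": {},
"source": [
"For this single-channel, stride-1 setup, the effect of `SparseConv2d` should be that of a dense convolution followed by re-masking of the output. The check below is a sketch that assumes this equivalence holds for SparK's implementation; it is not part of the original notebook:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sanity sketch: sparse conv ~ dense conv + re-masking (assumed equivalence, see note above)\n",
"emu = dense_conv(raw_inp) * active.float() # convolve densely, then zero out masked positions\n",
"emu = (emu > 0) * 255. # binarize, as above\n",
"print(torch.equal(emu, spconv)) # True if the emulation matches SparseConv2d here"
]
},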
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA7YAAAL+CAYAAACZhogtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA/6klEQVR4nO3deZhcVZ0//k93OktDNghLFkKCgAIhYMKaECDbMPAAsukADmtQkcUFle/DgI+II6KMjjIoLixRMDiOZAaHMcmwJcqACISQBAgKIZAYyQaYANnT9/cHv5R9e63eUnW6X6/n6edJVc49depW3VvnfZdzKrIsywIAAAASVVnqBgAAAEBbCLYAAAAkTbAFAAAgaYItAAAASRNsAQAASJpgCwAAQNIEWwAAAJIm2AIAAJA0wRYAAICkVZW6AZTenDlzYsKECYXHs2fPjvHjx5euQQBAu8qyLBYvXhwvvPBCLFu2LNatWxe9evWKAQMGxMiRI+PQQw+Nbt26lbqZAK0m2AIAdEJr166N+++/P37zm9/E7NmzY82aNY2W7du3b1x44YXxhS98IYYPH77jGgnQTlyKDECjli1bFtOnT49rrrkmJk6cGP369YuKiorC31e/+tVSNxFowOzZs2OPPfaIiy66KH71q181GWojItatWxe33nprjBw5MqZOnbqDWgnQfpyxBSBnw4YNcfbZZ8fTTz8dK1asKHVzgFZYu3ZtbN68Ofdcz54946ijjooDDjgg9thjj9i0aVM8//zzMXv27Ni4cWNERLz77rsxZcqU2LBhQ1x++eWlaDpAqwi2xPjx4yPLslI3AygTmzZtigceeKDUzQDaQUVFRUyaNCkuvfTSOOWUU6JXr171yqxYsSI+85nPxH333Vd47jOf+Uwcc8wxceihh+7I5gK0mkuRASjK0KFD4+ijjy51M4AiVFRUxEknnRTz5s2Lhx56KD760Y82GGojIgYOHBj/8R//Eeeee27huZqamrj22mt3VHMB2qwic6oOgFr++te/xgc/+ME44ogj4ogjjojDDz88jjjiiNhzzz3rjaJ+/fXXu88WytDWrVujqqplF+a99dZbMXz48HjnnXciIqJ79+6xZs2a6Nu3b0c0EaBduRQZgJz+/fvHqlWrSt0MoA1aGmojInbdddc44YQTYvr06RERsWXLlpg3b14cf/zx7d08gHaXdLD905/+FPPnz4833ngj3n333Rg+fHh8/OMfb7Ds6tWr4/nnn49XXnkl3n777di2bVvssssuMXjw4BgzZkzsvvvuO7j1ndO8efPihRdeiOXLl0evXr1i6NChMX78+Nh1113bpf6VK1fGE088EcuXL4/169fHoEGD4sgjj4wPfehD7VI/ADveW2+9FU888US88cYb8eabb0ZlZWUMGDAgDjzwwBg1alRUV1e3qL4//elPMXfu3Fi1alVs2LAhdttttxg2bFiMGzeuxXU15fXXX4+nnnoqli1bFjU1NbHnnnvGuHHjYp999mm319jR9ttvv9zjlStXlqgllNLWrVtjwYIF8fzzz8eaNWvivffei169ekX//v1j2LBhMWLEiBgyZEir63/ttdfimWeeiT//+c+xdevWGDJkSIwbNy6GDh3a5ra/8cYb8eSTT8bKlSvjrbfeih49esTuu+8eBx10UHz4wx+O7t27t7jO+fPnx8KFC2PVqlWxZcuW2GOPPWLfffeNMWPGtKq+xrz00kvx3HPPxbJly6KqqioGDRoUxx9/fAwaNKjdXqNTy8rY9ddfn0VE4W+7+++/PzvssMNy/xcRWb9+/XLLP/3009mXvvSlbMSIEfXK1v0bM2ZM9utf/7rZNk2bNi233OLFi5ss/4Mf/CBXvm/fvtmWLVuaXOayyy4rlB88eHCzbWqr2bNn59o4e/bsRsteeOGFhXLDhg0rPH/fffdlBx98cIPrtrKyMrvwwguzFStWNNuWJUuW5JadOnVq4fnTTjstq6qqavA1DjvssOx3v/tdG9cEqXnzzTezBx54IPvJT36S3XTTTdm3vvWt7I477sgef/zxbP369S2u749//GN27733Zt/73veym266Kbv99tuzBx98sFV1NeW1117L/uM//iP7zne+k/3Lv/xLdvfdd2evvvpqu75GR6m7v7j++utL3SQSVVNTk/3qV7/KjjrqqKyysrLR3+devXplJ554YjZ9+vQm69u6dWv2ox/9KNt3330brau6ujo799xzm/3t3m7q1Km55ZcsWZJlWZYtWLAg+7u/+7usoqKiwdc59thjs2effbbRet9+++2sV69ehfJnnXVW0ettu1tvvTX3mrNmzWpxHQ25/PLLc/X+53/+Z7vUSxrWrl2b/b//9/+y3Xffvdm+81577ZVdfvnl2apVq+rV01jfct68ednEiRMb3HYqKiqyCRMmZAsWLGhxuzdv3pzdfvvtjfZFt//17t07O+uss7JHHnmk2TrXr1+f3XTTTdngwYMbra9v377ZZZddVlQfN8sazzZz5szJjj766EZf5yMf+Uj2yiuvNFrvokWLcuW/+MUvFtWe2r74xS/m6njppZdaXEepJRdsr7zyykY/9NrB9u233252g2zob8qUKdmmTZsabdOKFSty5X/84x83+R5OP/30eq/x+OOPN7nM/vvvXyh7/vnnF7/CWqktwbampib7zGc+U9S63XfffbPXX3+9ybY0FGwfffTRbJdddmm2/oqKiuwb3/hGO68dyo3OcGk7w4It7eGNN97Ixo4d26Lf57oHr2tbsWJFgwe8G/vr0aNH4cBpUxralu++++7cdtjY30477ZQ9+OCDjdZ97rnn5trz5ptvtmgdjh49urD8XnvtlW3btq1Fyzdm3Lhxuffx1FNPtUu9lL8//vGP2d57793ivvPvf//7enU11Le89957i9p2unfvnt19991Ft/ull17KDjjggBa1+dBDD212XXzgAx8our4+ffpkM2fObLatDWWbb37zm1m3bt2afY3ddtste+655xqte8yYMYWyAwcObPZEWm1btmzJ9txzz8LyY8eOLXrZcpLUpcg333xzfP/734+IiD59+sTkyZNjv/32i27dusVrr70Wv//97xtcrrKyMg466KA46KCDYsiQIdGnT5/YunVrrFy5Mp555pmYP39+oexdd90VvXv3jltuuaXBuvbcc884+OCD4/nnn4+IiIcffjg+9alPNVh227ZtMWfOnHrPP/zwwzF27NgGl1m6dGm8/PLLhceTJ09usFy5+OpXvxq33nprRLy/biZPnhzDhg2LrVu3xsKFC+Ohhx6KrVu3RkTE4sWL46KLLopHHnkkKioqiqp/2bJl8cUvfjHefvvtiIg4/PDDY8yYMdGvX79YtmxZzJw5s3AvYJZlce2118bOO+8cn/3sZzvg3VJqK1asiLPOOiueeOKJZstu3LgxZs2aFb///e/jzDPPbLDMypUr4+STT465c+c2WdeGDRviF7/4RUyfPj1+/OMfx0UXXdTitt9zzz3xqU99qjBXZEMee+yxGDduXNx///3xd
3/3d/X+v3///nHGGWfEL37xi4iIeOCBB+Ktt95q0aX+U6dOLfx7r732avB1oCO9+uqrcdxxx8Xy5ctzz++3335x7LHHxsCBA6OioiJWrVoVzz33XMybNy+2bdvWaH1vvvlmjB07Nl599dXCcxUVFTF27Ng4/PDDo3fv3rF06dKYMWNGvPnmmxERsXnz5rj44otjw4YNcdlllxXd9ocffjguu+yy2Lp1a+y0004xceLEOOCAA6JXr16xePHimDFjRqxduzYiItavXx8f//jH48UXX2zwdqeLL764sC1v3rw57r333rjyyiuLasfChQvj2WefLTy+8MILo7Ky7RNdLFmyJB5//PHC4/79+8eoUaPaXC/lb9OmTXHqqafG0qVLC8/17t07jjvuuPjgBz8Y/fr1i82bN8fbb78dixYtinnz5sW6deuKrv/ZZ5+Na6+9NjZt2hSVlZVx7LHHxujRo6O6ujpeffXVmDlzZmHb2bJlS1x00UXRu3fvOOOMM5qs96mnnoq///u/j7/+9a+550e
"text/plain": [
"<Figure size 1200x900 with 12 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# visualization\n",
"\n",
"from typing import List\n",
"import PIL.Image as PImage\n",
"import matplotlib.pyplot as plt\n",
"from torchvision import transforms\n",
"\n",
"raw_inps, conv1s, conv2s, spconvs = [], [], [], []\n",
"for i in range(N):\n",
" raw_inps.append(raw_inp[i].repeat(3, 1, 1).clamp(1, 255))\n",
" conv1s.append(conv1[i].repeat(3, 1, 1).clamp(1, 255))\n",
" conv2s.append(conv2[i].repeat(3, 1, 1).clamp(1, 255))\n",
" spconvs.append(spconv[i].repeat(3, 1, 1).clamp(1, 255))\n",
"\n",
"n_rows, n_cols = N, 4\n",
"plt.figure(figsize=(n_cols * 3, n_rows * 3))\n",
"\n",
"tensor2pil = transforms.ToPILImage()\n",
"for col, (title, chws) in enumerate([('raw_inp', raw_inps), ('conv1', conv1s), ('conv2', conv2s), ('spconv', spconvs)]):\n",
" p_imgs: List[PImage.Image] = [tensor2pil(t).convert('RGB') for t in chws]\n",
" for row, im in enumerate(p_imgs):\n",
" plt.subplot(n_rows, n_cols, 1 + row * n_cols + col)\n",
" plt.xticks([]), plt.yticks([]), plt.axis('off')\n",
" plt.imshow(im)\n",
" if row == 0:\n",
" plt.title(title, size=28, pad=18)\n",
"plt.show()\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}