You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

253 lines
9.6 KiB

2 years ago
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from pprint import pformat
from typing import List
import sys
2 years ago
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
import encoder
from decoder import LightDecoder
import debug_tools as D
2 years ago
class SparK(nn.Module):
    """Masked-image-modeling wrapper around a sparse encoder and a dense decoder.

    The forward pass (1) builds or accepts a boolean mask over the
    ``fmap_h x fmap_w`` feature grid, (2) runs the sparse encoder on the masked
    input, (3) "densifies" every encoder feature map by normalising it, filling
    the non-active positions with a learned mask token and projecting it to the
    decoder width, (4) decodes, and (5) computes a per-patch normalised L2 loss
    on the masked (non-active) patches only.

    Args:
        sparse_encoder: Encoder exposing ``input_size``, ``downsample_raito``
            and ``enc_feat_map_chs``; calling it returns a list of feature maps.
        dense_decoder: Decoder exposing ``width``; called with the list of
            densified feature maps (smallest map first).
        mask_ratio: Fraction of feature-grid positions that are masked out.
        densify_norm: ``"bn"``, ``"ln"`` or anything else (no normalisation);
            compared case-insensitively.
        sbn: When ``densify_norm`` is ``"bn"``, selects the synchronised
            batch-norm variant from the ``encoder`` module.
    """

    def __init__(
        self,
        sparse_encoder: "encoder.SparseEncoder",
        dense_decoder: "LightDecoder",
        mask_ratio=0.6,
        densify_norm="bn",
        sbn=False,
    ):
        super().__init__()
        input_size = sparse_encoder.input_size
        downsample_raito = sparse_encoder.downsample_raito
        self.downsample_raito = downsample_raito
        self.fmap_h = input_size // downsample_raito
        self.fmap_w = input_size // downsample_raito
        self.mask_ratio = mask_ratio
        # number of feature-grid positions that stay visible (active)
        self.len_keep = round(self.fmap_h * self.fmap_w * (1 - mask_ratio))

        self.sparse_encoder = sparse_encoder
        self.dense_decoder = dense_decoder

        self.sbn = sbn
        self.hierarchy = len(sparse_encoder.enc_feat_map_chs)
        self.densify_norm_str = densify_norm.lower()
        self.densify_norms = nn.ModuleList()
        self.densify_projs = nn.ModuleList()
        self.mask_tokens = nn.ParameterList()

        # build the `densify` layers
        # Work on a copy: popping from the encoder's own list would leave
        # `sparse_encoder.enc_feat_map_chs` empty after construction.
        e_widths: List[int] = list(self.sparse_encoder.enc_feat_map_chs)
        d_width = self.dense_decoder.width
        # from the smallest feat map to the largest;
        # i=0: the last feat map; i=1: the second last feat map ...
        for i, e_width in enumerate(reversed(e_widths)):
            # create mask token
            p = nn.Parameter(torch.zeros(1, e_width, 1, 1))
            trunc_normal_(p, mean=0, std=0.02, a=-0.02, b=0.02)
            self.mask_tokens.append(p)

            # create densify norm
            if self.densify_norm_str == "bn":
                norm_cls = (
                    encoder.SparseSyncBatchNorm2d
                    if self.sbn
                    else encoder.SparseBatchNorm2d
                )
                densify_norm = norm_cls(e_width)
            elif self.densify_norm_str == "ln":
                densify_norm = encoder.SparseConvNeXtLayerNorm(
                    e_width, data_format="channels_first", sparse=True
                )
            else:
                densify_norm = nn.Identity()
            self.densify_norms.append(densify_norm)

            # create densify proj
            if i == 0 and e_width == d_width:
                # todo: NOTE THAT CONVNEXT-S WOULD USE THIS, because it has a
                # width of 768 that equals to the decoder's width 768
                densify_proj = nn.Identity()
                print(
                    f"[SparK.__init__, densify {i+1}/{self.hierarchy}]: use nn.Identity() as densify_proj"
                )
            else:
                kernel_size = 1 if i <= 0 else 3
                densify_proj = nn.Conv2d(
                    e_width,
                    d_width,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=kernel_size // 2,
                    bias=True,
                )
                print(
                    f"[SparK.__init__, densify {i+1}/{self.hierarchy}]: densify_proj(ksz={kernel_size}, #para={sum(x.numel() for x in densify_proj.parameters()) / 1e6:.2f}M)"
                )
            self.densify_projs.append(densify_proj)

            # todo: the decoder's width follows a simple halfing rule; you can change it to any other rule
            d_width //= 2

        print(
            f"[SparK.__init__] dims of mask_tokens={tuple(p.numel() for p in self.mask_tokens)}"
        )

        # these are deprecated and would never be used; can be removed.
        self.register_buffer("imn_m", torch.empty(1, 3, 1, 1))
        self.register_buffer("imn_s", torch.empty(1, 3, 1, 1))
        self.register_buffer("norm_black", torch.zeros(1, 3, input_size, input_size))
        self.vis_active = self.vis_active_ex = self.vis_inp = self.vis_inp_mask = ...

    def mask(self, B: int, device, generator=None):
        """Draw a random boolean mask with ``len_keep`` active positions per sample.

        Args:
            B: Batch size.
            device: Device the returned mask is placed on.
            generator: Optional ``torch.Generator`` forwarded to ``torch.rand``.

        Returns:
            Bool tensor of shape ``(B, 1, fmap_h, fmap_w)``; ``True`` marks an
            active (kept) position.
        """
        h, w = self.fmap_h, self.fmap_w
        # argsort of uniform noise yields a random permutation per row
        idx = torch.rand(B, h * w, generator=generator).argsort(dim=1)
        idx = idx[:, : self.len_keep].to(device)  # (B, len_keep)
        active = torch.zeros(B, h * w, dtype=torch.bool, device=device)
        active.scatter_(dim=1, index=idx, value=True)
        return active.view(B, 1, h, w)

    def forward(self, inp_bchw: torch.Tensor, active_b1ff=None, vis=False):
        """Mask, encode, densify, decode and score a batch.

        Args:
            inp_bchw: Input batch, ``(B, C, H, W)``.
            active_b1ff: Optional bool mask ``(B, 1, f, f)``; a random one is
                drawn with :meth:`mask` when omitted.
            vis: When true, return tensors for visualisation instead of the loss.

        Returns:
            The scalar reconstruction loss, or, if ``vis`` is true, the tuple
            ``(inp_bchw, masked_bchw, rec_or_inp)`` where ``rec_or_inp`` keeps
            the input on active pixels and the reconstruction elsewhere.
        """
        # step1. Mask
        if active_b1ff is None:  # rand mask
            active_b1ff = self.mask(inp_bchw.shape[0], inp_bchw.device)  # (B, 1, f, f)
        # the mask is handed to the encoder module through a module-level global
        encoder._cur_active = active_b1ff  # (B, 1, f, f)
        active_b1hw = active_b1ff.repeat_interleave(
            self.downsample_raito, 2
        ).repeat_interleave(self.downsample_raito, 3)  # (B, 1, H, W)
        masked_bchw = inp_bchw * active_b1hw

        # step2. Encode: get hierarchical encoded sparse features
        # (a list containing 4 feature maps at 4 scales)
        fea_bcffs: List[torch.Tensor] = self.sparse_encoder(masked_bchw)
        fea_bcffs.reverse()  # after reversion: from the smallest feature map to the largest

        # step3. Densify: get hierarchical dense features for decoding
        cur_active = active_b1ff  # (B, 1, f, f)
        to_dec = []
        for i, bcff in enumerate(fea_bcffs):  # from the smallest feature map to the largest
            if bcff is not None:
                bcff = self.densify_norms[i](bcff)
                mask_tokens = self.mask_tokens[i].expand_as(bcff)
                # fill in empty (non-active) positions with [mask] tokens
                bcff = torch.where(cur_active.expand_as(bcff), bcff, mask_tokens)
                bcff = self.densify_projs[i](bcff)
            to_dec.append(bcff)
            # dilate the mask map, from (B, 1, f, f) to (B, 1, H, W)
            cur_active = cur_active.repeat_interleave(2, dim=2).repeat_interleave(
                2, dim=3
            )

        # step4. Decode and reconstruct
        rec_bchw = self.dense_decoder(to_dec)
        # inp and rec: (B, L = f*f, N = C*downsample_raito**2)
        inp = self.patchify(inp_bchw)
        rec = self.patchify(rec_bchw)
        # normalise every target patch by its own mean and standard deviation
        mean = inp.mean(dim=-1, keepdim=True)
        std = (inp.var(dim=-1, keepdim=True) + 1e-6) ** 0.5
        inp = (inp - mean) / std
        l2_loss = ((rec - inp) ** 2).mean(dim=2, keepdim=False)  # (B, L, C) ==mean==> (B, L)
        non_active = (
            active_b1ff.logical_not().int().view(active_b1ff.shape[0], -1)
        )  # (B, 1, f, f) => (B, L)
        # loss only on masked (non-active) patches
        recon_loss = l2_loss.mul_(non_active).sum() / (non_active.sum() + 1e-8)

        if vis:
            # undo the per-patch normalisation before folding patches back
            rec_bchw = self.unpatchify(rec * std + mean)
            rec_or_inp = torch.where(active_b1hw, inp_bchw, rec_bchw)
            return inp_bchw, masked_bchw, rec_or_inp
        else:
            return recon_loss

    def patchify(self, bchw):
        """Split ``(B, C, h*p, w*p)`` into patches of shape ``(B, h*w, p*p*C)``.

        ``p`` is ``downsample_raito`` and ``h, w`` are ``fmap_h, fmap_w``.
        """
        p = self.downsample_raito
        h, w = self.fmap_h, self.fmap_w
        B, C = bchw.shape[:2]
        bchw = bchw.reshape(shape=(B, C, h, p, w, p))
        bchw = torch.einsum("bchpwq->bhwpqc", bchw)
        bln = bchw.reshape(shape=(B, h * w, C * p**2))  # (B, f*f, 3*downsample_raito**2)
        return bln

    def unpatchify(self, bln):
        """Inverse of :meth:`patchify`: ``(B, h*w, p*p*C)`` -> ``(B, C, h*p, w*p)``."""
        p = self.downsample_raito
        h, w = self.fmap_h, self.fmap_w
        B, C = bln.shape[0], bln.shape[-1] // p**2
        bln = bln.reshape(shape=(B, h, w, p, p, C))
        bln = torch.einsum("bhwpqc->bchpwq", bln)
        bchw = bln.reshape(shape=(B, C, h * p, w * p))
        return bchw

    def __repr__(self):
        """Return the config dict followed by the default module structure dump."""
        return (
            f'\n'
            f'[SparK.config]: {pformat(self.get_config(), indent=2, width=250)}\n'
            f'[SparK.structure]: {super().__repr__().replace(SparK.__name__, "")}'
        )

    def get_config(self):
        """Return the hyper-parameters that identify this model as a flat dict."""
        return {
            # self
            "mask_ratio": self.mask_ratio,
            "densify_norm_str": self.densify_norm_str,
            "sbn": self.sbn,
            "hierarchy": self.hierarchy,
            # enc
            "sparse_encoder.input_size": self.sparse_encoder.input_size,
            # dec
            "dense_decoder.width": self.dense_decoder.width,
        }

    def state_dict(
        self, destination=None, prefix="", keep_vars=False, with_config=False
    ):
        """Return the module state dict, optionally with a ``"config"`` entry.

        Args:
            destination, prefix, keep_vars: Forwarded to ``nn.Module.state_dict``.
            with_config: When true, store :meth:`get_config` under ``"config"``.
        """
        state = super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
        if with_config:
            state["config"] = self.get_config()
        return state

    def load_state_dict(self, state_dict, strict=True):
        """Load weights and compare any stored ``"config"`` with this model.

        The ``"config"`` entry, if present, is popped from ``state_dict``
        before the weights are loaded, so the caller's mapping loses that key.

        Args:
            state_dict: Mapping as produced by :meth:`state_dict`.
            strict: Forwarded to ``nn.Module.load_state_dict``; also decides
                whether a config mismatch raises or is only printed to stderr.

        Returns:
            The incompatible-keys result of ``nn.Module.load_state_dict``.

        Raises:
            AttributeError: If ``strict`` is true and a config value differs.
        """
        config: dict = state_dict.pop("config", None)
        incompatible_keys = super().load_state_dict(state_dict, strict=strict)
        if config is not None:
            for k, v in self.get_config().items():
                ckpt_v = config.get(k, None)
                if ckpt_v != v:
                    err = f"[SparseMIM.load_state_dict] config mismatch: this.{k}={v} (ckpt.{k}={ckpt_v})"
                    if strict:
                        raise AttributeError(err)
                    else:
                        print(err, file=sys.stderr)
        return incompatible_keys