# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import random
from pprint import pformat
from typing import List

import numpy as np
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from timm.models.layers import trunc_normal_

import encoder
from decoder import LightDecoder


class SparK(nn.Module):
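    """
    SparK pre-training module: a sparse encoder plus a light dense decoder for
    masked image modeling. In brief (summarizing the code below): `mask()` draws
    a random boolean mask at the final feature-map resolution, the sparse
    encoder sees only the visible patches, masked positions are filled with
    learnable [mask] tokens (optionally plus frozen 2-D sin-cos positional
    embeddings), and the dense decoder reconstructs the image; the loss is
    computed on masked patches only, unless `dense_loss=True`.
    """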
    def __init__(
        self, sparse_encoder: encoder.SparseEncoder, dense_decoder: LightDecoder,
        mask_ratio=0.6, mask_ratio2=0.6, uniform=False,
        using_pe=True, pix_norm=0, dense_loss=False, loss_l2=True,
        en_de_norm='bn', en_de_lin=True, sbn=False, pyramid=1,  # 1 for single-scale pre-training; 4 for full-scale pre-training
    ):
        super().__init__()
        input_size, downsample_raito = sparse_encoder.input_size, sparse_encoder.downsample_raito
        self.downsample_raito = downsample_raito
        fmap_size = input_size // downsample_raito
        self.fmap_size = fmap_size
        if mask_ratio != mask_ratio2 and not uniform:  # compensate the ratios for the one extra active site forced in mask()
            k = 1 / fmap_size**2
            mask_ratio = min(1, mask_ratio / (1-k))
            mask_ratio2 = min(1, mask_ratio2 / (1-k))
        self.mask_ratio = (mask_ratio, mask_ratio2)
        self.ratios = torch.tensor([self.mask_ratio[0], self.mask_ratio[1], (self.mask_ratio[0] + self.mask_ratio[1]) / 2])
        self.uniform = uniform
        self.len_keep = round(fmap_size * fmap_size * (1-mask_ratio))
        self.pix_norm = int(pix_norm)
        
        self.sparse_encoder = sparse_encoder
        self.dense_decoder = dense_decoder
        
        self.sbn = sbn
        self.pyramid = pyramid
        en_de_norm = en_de_norm.lower()
        self.en_de_norm_str = en_de_norm
        
        self.en_de_lin_bool = en_de_lin
        self.en_de_norms = nn.ModuleList()
        self.en_de_lins = nn.ModuleList()
        self.using_pe = using_pe
        self.pos_embeds = nn.ParameterList()
        self.mask_tokens = nn.ParameterList()
        
        # one (norm, projection, pos-embed, [mask] token) group per pyramid scale;
        # feature width halves and feature-map size doubles at each finer scale
        fea, d_fea, fmap = self.sparse_encoder.fea_dim, self.dense_decoder.fea_dim, fmap_size
        for i in range(self.pyramid):
            if en_de_norm == 'bn':
                n = (encoder.SparseSyncBatchNorm2d if sbn else encoder.SparseBatchNorm2d)(fea)
            elif en_de_norm == 'ln':
                n = encoder.SparseConvNeXtLayerNorm(fea, data_format='channels_first', sparse=True)
            else:
                n = nn.Identity()
            self.en_de_norms.append(n)
            
            kernel_size = 1 if i <= 0 else 3
            l = nn.Conv2d(fea, d_fea, kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=True)
            print(f'[mid, py={self.pyramid}][edl {i}]: k={kernel_size}, #para = {sum(x.numel() for x in l.parameters())/1e6:.2f}')
            if i == 0 and fea == d_fea:
                l = nn.Identity()
            self.en_de_lins.append(l)
            
            if self.using_pe:
                p = torch.from_numpy(get_2d_sincos_pos_embed(fea, fmap)).float()
                p = p.reshape(1, fmap, fmap, fea).permute(0, 3, 1, 2).contiguous()
                p = nn.Parameter(p, requires_grad=False)
                self.pos_embeds.append(p)
            
            p = nn.Parameter(torch.zeros(1, fea, 1, 1))
            trunc_normal_(p, mean=0, std=.02, a=-.02, b=.02)
            self.mask_tokens.append(p)
            fea //= 2
            d_fea //= 2
            fmap *= 2
        
        print(f'[mid, py={self.pyramid}][mask_tokens]: {tuple(p.numel() for p in self.mask_tokens)}')
        
        self.loss_l2, self.dense_loss = loss_l2, dense_loss
        
        m = torch.tensor(IMAGENET_DEFAULT_MEAN).view(1, 3, 1, 1)
        s = torch.tensor(IMAGENET_DEFAULT_STD).view(1, 3, 1, 1)
        self.register_buffer('imn_m', m)
        self.register_buffer('imn_s', s)
        # self.register_buffer('norm_black', (torch.ones(1, 3, input_size, input_size) * 0.45 - m) / s)
        self.register_buffer('norm_black', torch.zeros(1, 3, input_size, input_size))
        self.vis_active = self.vis_active_ex = self.vis_inp = self.vis_inp_mask = ...

    def mask(self, shape, device, generator=None):
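        """
        Draw a random boolean mask at the final feature-map resolution
        (summarizing the three branches below): a fixed ratio when
        mask_ratio == mask_ratio2; a single ratio drawn uniformly from
        [mask_ratio, mask_ratio2] when `uniform`; otherwise the batch is split
        into three groups, each masked with one of three fixed ratios, plus one
        forced active site per sample.
        Returns `active` of shape (B, 1, f, f); True marks a kept (visible) patch.
        """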
        B, C, H, W = shape
        f = self.fmap_size
        if self.mask_ratio[0] == self.mask_ratio[1]:
            len_keep = self.len_keep
        elif self.uniform:
            r = random.uniform(self.mask_ratio[0], self.mask_ratio[1])
            len_keep = round(f * f * (1-r))
        else:
            i1, i2, i3, i4 = np.linspace(0, B, 4, dtype=int).tolist()
            l1, l2, l3 = i2-i1, i3-i2, i4-i3
            r1, r2, r3 = self.ratios[torch.randperm(3, generator=generator)].tolist()
            r = torch.tensor([r1]*l1 + [r2]*l2 + [r3]*l3, device=device).view(-1, 1, 1)
            active = torch.rand(B, f, f, device=device, generator=generator) >= r
            
            rr, cc = torch.randint(low=0, high=f, size=(2, B), generator=generator).unbind(0)
            active[torch.arange(B), rr, cc] = True  # an extra active site
            return active.unsqueeze_(1)
        
        idx = torch.rand(B, f*f, generator=generator).argsort(dim=1)
        idx = idx[:, :len_keep].to(device)  # (B, len_keep)
        return torch.zeros(B, f*f, dtype=torch.bool, device=device).scatter_(dim=1, index=idx, value=True).view(B, 1, f, f)

    def forward(self, raw_inp: torch.Tensor, active=None):
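        """
        Mask the input, encode the visible patches sparsely, fill masked
        positions with [mask] tokens, decode a dense reconstruction, and
        compute the reconstruction loss.
        Returns (active_ex, rec_bchw, spatial_loss): the pixel-level mask
        (B, 1, H, W), the reconstruction (B, C, H, W), and the scalar loss.
        """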
        inp_bchw = raw_inp
        
        # spatial mask
        if active is None:
            active: torch.BoolTensor = self.mask(inp_bchw.shape, inp_bchw.device)  # (B, 1, f, f)
        encoder._cur_active = active
        active_ex = active.repeat_interleave(self.downsample_raito, 2).repeat_interleave(self.downsample_raito, 3)  # (B, 1, H, W)
        masked_bchw = inp_bchw * active_ex
        
        # get hierarchical encoded sparse features (a list containing four feature maps)
        fea_bcffs: List[torch.Tensor] = self.sparse_encoder(masked_bchw, pyramid=self.pyramid)
        fea_bcffs.reverse()  # from the smallest feature map to the largest
        
        cur_active = active
        to_dec = []
        for i, bcff in enumerate(fea_bcffs):  # from the smallest feature map to the largest
            if bcff is not None:
                # fill in empty positions with [mask] embeddings
                bcff = self.en_de_norms[i](bcff)
                mask_tokens = self.mask_tokens[i].expand_as(bcff)
                if self.using_pe:
                    mask_tokens = mask_tokens + self.pos_embeds[i].expand_as(bcff)
                bcff = torch.where(cur_active.expand_as(bcff), bcff, mask_tokens)
                bcff: torch.Tensor = self.en_de_lins[i](bcff)
                to_dec.append(bcff)
            cur_active = cur_active.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3)  # dilate the mask map
        
        # decode and reconstruct
        rec_bchw = self.dense_decoder(to_dec)
        
        # calc loss
        mean, var, spatial_loss = self.spatial_loss(raw_inp, rec_bchw, active)
        return active_ex, rec_bchw, spatial_loss

    def spatial_loss(self, inp, rec, active):  # active: (B, 1, f, f)
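        """
        Per-patch reconstruction loss, L2 if `loss_l2` else L1. `pix_norm`
        selects the regression target: 0 = raw pixels; 1 = pixels normalized by
        each patch's own mean/std; 2 = per-image mean added back to the
        reconstruction first. Unless `dense_loss` is set, the loss is averaged
        over masked (removed) patches only.
        """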
        mean = var = None
        if self.pix_norm == 2:
            mean, var = inp.mean(dim=(2, 3), keepdim=True), None
            rec = rec + mean
        inp = self.patchify(inp)
        rec = self.patchify(rec)
        # (B, L=fmap_size**2, N=downsample_raito**2 * C)
        if self.pix_norm == 1:
            mean = inp.mean(dim=-1, keepdim=True)
            var = (inp.var(dim=-1, keepdim=True) + 1e-6)**.5
            inp = (inp - mean) / var
        loss_spa = (rec-inp)**2 if self.loss_l2 else (rec-inp).abs()
        
        if self.dense_loss:
            return mean, var, loss_spa.mean()  # mean loss on all patches
        else:
            loss_spa = loss_spa.mean(dim=2, keepdim=False)  # (B, L, C) => (B, L)
            non_active = active.logical_not().int().view(active.shape[0], -1)  # (B, 1, f, f) => (B, L)
            return mean, var, loss_spa.mul_(non_active).sum() / (non_active.sum() + 1e-8)  # mean loss on removed patches

    def patchify(self, bchw):
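        """(B, C, H, W) -> (B, L, N), with L = fmap_size**2 patches of N = downsample_raito**2 * C pixel values each."""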
        p = self.downsample_raito
        h = w = self.fmap_size
        B, C = bchw.shape[:2]
        bchw = bchw.reshape(shape=(B, C, h, p, w, p))
        bchw = torch.einsum('bchpwq->bhwpqc', bchw)
        bln = bchw.reshape(shape=(B, h*w, p**2 * C))  # (B, f*f, downsample_raito*downsample_raito*3)
        return bln

    def unpatchify(self, bln):
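        """(B, L, N) -> (B, C, H, W); the inverse of `patchify`."""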
        p = self.downsample_raito
        h = w = self.fmap_size
        B, C = bln.shape[0], bln.shape[-1] // p**2
        bln = bln.reshape(shape=(B, h, w, p, p, C))
        bln = torch.einsum('bhwpqc->bchpwq', bln)
        bchw = bln.reshape(shape=(B, C, h * p, w * p))
        return bchw

    def denorm_for_vis(self, x, clamp):
        x = x * self.imn_s
        x += self.imn_m
        if clamp:
            x = torch.clamp(x, 0, 1)
        return x

    def __repr__(self):
        return (
            f'\n'
            f'[SparK.config]: {pformat(self.get_config(), indent=2, width=250)}\n'
            f'[SparK.structure]: {super(SparK, self).__repr__().replace(SparK.__name__, "")}\n'
            f'[SparK.dec]: {self.dense_decoder.num_para()}'
        )

    def get_config(self):
        return {
            # self
            'mask_ratio': self.mask_ratio[0], 'mask_ratio2': self.mask_ratio[1], 'uniform': self.uniform,
            'using_pe': self.using_pe, 'pix_norm': self.pix_norm,
            'dense_loss': self.dense_loss, 'loss_l2': self.loss_l2,
            'en_de_norm': self.en_de_norm_str, 'en_de_lin': self.en_de_lin_bool, 'sbn': self.sbn, 'pyramid': self.pyramid,
            
            # enc
            'input_size': self.sparse_encoder.input_size,
            # dec
            'dec_fea_dim': self.dense_decoder.fea_dim, 'double': self.dense_decoder.double_bool, 'heavy': self.dense_decoder.heavy,
        }

    def state_dict(self, destination=None, prefix='', keep_vars=False, with_config=False):
        state = super(SparK, self).state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        if with_config:
            state['config'] = self.get_config()  # todo: this seems to trigger a DDP broadcast error??
        return state

    def load_state_dict(self, state_dict, strict=True):
        config = state_dict.pop('config', None)
        incompatible_keys = super(SparK, self).load_state_dict(state_dict, strict=strict)
        if config is not None:
            for k, v in self.get_config().items():
                if config.get(k, None) != v:
                    err = f'[SparK.load_state_dict] config mismatch: this.{k}={v} (ckpt.{k}={config.get(k, None)})'
                    if strict:
                        raise AttributeError(err)
                    else:
                        print(err)
        
        return incompatible_keys
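

# A hypothetical usage sketch (the constructor arguments of SparseEncoder and
# LightDecoder live in encoder.py / decoder.py and are assumptions here):
#   enc = encoder.SparseEncoder(...)
#   dec = LightDecoder(...)
#   spark = SparK(enc, dec, mask_ratio=0.6, pyramid=4)
#   active_ex, rec_bchw, loss = spark(img_bchw)  # forward() returns (mask, reconstruction, loss)
#   loss.backward()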


def _make_divisible(v, divisor=8, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by `divisor`.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
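
# Illustrative values: _make_divisible(37) == 40, _make_divisible(20) == 24,
# and _make_divisible(7) == 8 (never rounding down by more than 10%).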


def get_2d_sincos_pos_embed(embed_dim, grid_size):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim]
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    
    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0
    
    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
    
    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)  # note: np.float was removed in NumPy 1.24; use the builtin float
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)
    
    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
    
    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    
    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
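
# Note: the helpers above build the fixed (non-learned) sin-cos encoding,
# sin(p * omega) and cos(p * omega) with omega = 1 / 10000**(2i / D); in the
# 2-D case, one half of the channels encodes one grid axis and the other half
# the other. SparK computes these once in __init__ and keeps them frozen
# (requires_grad=False).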


if __name__ == '__main__':
    # NOTE: SparK.test_mask() and SparK.test_align() were called here, but no
    # such methods are defined in this file; the calls are kept below, commented
    # out, for reference.
    # SparK.test_mask()
    # SparK.test_align()
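    # A minimal smoke test of the module-level helpers above (a sketch; it does
    # not build SparK itself, which needs the SparseEncoder/LightDecoder pair
    # from encoder.py / decoder.py):
    pe = get_2d_sincos_pos_embed(embed_dim=64, grid_size=7)
    assert pe.shape == (49, 64)
    print('2d sin-cos pos embed shape:', pe.shape)
    print('_make_divisible(37) =', _make_divisible(37))  # expect 40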