# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import random
from pprint import pformat
from typing import List

import numpy as np
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from timm.models.layers import trunc_normal_

import encoder
from decoder import LightDecoder


class SparK(nn.Module):
    def __init__(
        self, sparse_encoder: encoder.SparseEncoder, dense_decoder: LightDecoder,
        mask_ratio=0.6, mask_ratio2=0.6, uniform=False,
        using_pe=True, pix_norm=0, dense_loss=False, loss_l2=True,
        en_de_norm='bn', en_de_lin=True, sbn=False, pyramid=1,  # 1 for single-scale pre-training; 4 for full-scale pre-training
    ):
        super().__init__()
        input_size, downsample_raito = sparse_encoder.input_size, sparse_encoder.downsample_raito
        self.downsample_raito = downsample_raito
        fmap_size = input_size // downsample_raito
        self.fmap_size = fmap_size
        if mask_ratio != mask_ratio2 and not uniform:  # compensate for the one extra active site forced in `mask`
            k = 1 / fmap_size**2
            mask_ratio = min(1, mask_ratio / (1-k))
            mask_ratio2 = min(1, mask_ratio2 / (1-k))
        self.mask_ratio = (mask_ratio, mask_ratio2)
        self.ratios = torch.tensor([self.mask_ratio[0], self.mask_ratio[1], (self.mask_ratio[0] + self.mask_ratio[1]) / 2])
        self.uniform = uniform
        self.len_keep = round(fmap_size * fmap_size * (1-mask_ratio))
        self.pix_norm = int(pix_norm)
        
        self.sparse_encoder = sparse_encoder
        self.dense_decoder = dense_decoder
        
        self.sbn = sbn
        self.pyramid = pyramid
        en_de_norm = en_de_norm.lower()
        self.en_de_norm_str = en_de_norm
        
        self.en_de_lin_bool = en_de_lin
        self.en_de_norms = nn.ModuleList()
        self.en_de_lins = nn.ModuleList()
        self.using_pe = using_pe
        self.pos_embeds = nn.ParameterList()
        self.mask_tokens = nn.ParameterList()
        
        fea, d_fea, fmap = self.sparse_encoder.fea_dim, self.dense_decoder.fea_dim, fmap_size
        for i in range(self.pyramid):
            if en_de_norm == 'bn':
                n = (encoder.SparseSyncBatchNorm2d if sbn else encoder.SparseBatchNorm2d)(fea)
            elif en_de_norm == 'ln':
                n = encoder.SparseConvNeXtLayerNorm(fea, data_format='channels_first', sparse=True)
            else:
                n = nn.Identity()
            self.en_de_norms.append(n)
            
            kernel_size = 1 if i <= 0 else 3
            l = nn.Conv2d(fea, d_fea, kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=True)
            print(f'[mid, py={self.pyramid}][edl {i}]: k={kernel_size}, #para = {sum(x.numel() for x in l.parameters())/1e6:.2f}')
            if i == 0 and fea == d_fea:
                l = nn.Identity()
            self.en_de_lins.append(l)
            
            if self.using_pe:
                p = torch.from_numpy(get_2d_sincos_pos_embed(fea, fmap)).float()
                p = p.reshape(1, fmap, fmap, fea).permute(0, 3, 1, 2).contiguous()
                p = nn.Parameter(p, requires_grad=False)
                self.pos_embeds.append(p)
            
            p = nn.Parameter(torch.zeros(1, fea, 1, 1))
            trunc_normal_(p, mean=0, std=.02, a=-.02, b=.02)
            self.mask_tokens.append(p)
            # the next (larger) pyramid level has half the channels and twice the resolution
            fea //= 2
            d_fea //= 2
            fmap *= 2
        
        print(f'[mid, py={self.pyramid}][mask_tokens]: {tuple(p.numel() for p in self.mask_tokens)}')
        
        self.loss_l2, self.dense_loss = loss_l2, dense_loss
        
        m = torch.tensor(IMAGENET_DEFAULT_MEAN).view(1, 3, 1, 1)
        s = torch.tensor(IMAGENET_DEFAULT_STD).view(1, 3, 1, 1)
        self.register_buffer('imn_m', m)
        self.register_buffer('imn_s', s)
        # self.register_buffer('norm_black', (torch.ones(1, 3, input_size, input_size) * 0.45 - m) / s)
        self.register_buffer('norm_black', torch.zeros(1, 3, input_size, input_size))
        self.vis_active = self.vis_active_ex = self.vis_inp = self.vis_inp_mask = ...
    
    def mask(self, shape, device, generator=None):
        B, C, H, W = shape
        f = self.fmap_size
        if self.mask_ratio[0] == self.mask_ratio[1]:
            len_keep = self.len_keep
        elif self.uniform:
            r = random.uniform(self.mask_ratio[0], self.mask_ratio[1])
            len_keep = round(f * f * (1-r))
        else:
            i1, i2, i3, i4 = np.linspace(0, B, 4, dtype=int).tolist()
            l1, l2, l3 = i2-i1, i3-i2, i4-i3
            r1, r2, r3 = self.ratios[torch.randperm(3, generator=generator)].tolist()
            r = torch.tensor([r1]*l1 + [r2]*l2 + [r3]*l3, device=device).view(-1, 1, 1)
            active = torch.rand(B, f, f, device=device, generator=generator) >= r
            
            rr, cc = torch.randint(low=0, high=f, size=(2, B), generator=generator).unbind(0)
            active[torch.arange(B), rr, cc] = True  # an extra active site
            return active.unsqueeze_(1)
        
        idx = torch.rand(B, f*f, generator=generator).argsort(dim=1)
        idx = idx[:, :len_keep].to(device)  # (B, len_keep)
        return torch.zeros(B, f*f, dtype=torch.bool, device=device).scatter_(dim=1, index=idx, value=True).view(B, 1, f, f)
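    
    # Note on `mask`: the returned boolean map has shape (B, 1, fmap_size, fmap_size);
    # True marks a kept (visible) feature-map site, False a masked one. `forward` below
    # repeat-interleaves it by `downsample_raito` to get the pixel-level map `active_ex`.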
    
    def forward(self, raw_inp: torch.Tensor, active=None):
        inp_bchw = raw_inp
        
        # spatial mask
        if active is None:
            active: torch.BoolTensor = self.mask(inp_bchw.shape, inp_bchw.device)  # (B, 1, f, f)
        encoder._cur_active = active
        active_ex = active.repeat_interleave(self.downsample_raito, 2).repeat_interleave(self.downsample_raito, 3)  # (B, 1, H, W)
        masked_bchw = inp_bchw * active_ex
        
        # get hierarchical encoded sparse features (a list containing `pyramid` feature maps)
        fea_bcffs: List[torch.Tensor] = self.sparse_encoder(masked_bchw, pyramid=self.pyramid)
        fea_bcffs.reverse()  # from the smallest feature map to the largest
        
        cur_active = active
        to_dec = []
        for i, bcff in enumerate(fea_bcffs):  # from the smallest feature map to the largest
            if bcff is not None:
                # fill in empty positions with [mask] embeddings
                bcff = self.en_de_norms[i](bcff)
                mask_tokens = self.mask_tokens[i].expand_as(bcff)
                if self.using_pe:
                    mask_tokens = mask_tokens + self.pos_embeds[i].expand_as(bcff)
                bcff = torch.where(cur_active.expand_as(bcff), bcff, mask_tokens)
                bcff: torch.Tensor = self.en_de_lins[i](bcff)
            to_dec.append(bcff)
            cur_active = cur_active.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3)  # dilate the mask map
        
        # decode and reconstruct
        rec_bchw = self.dense_decoder(to_dec)
        
        # calc loss
        mean, var, spatial_loss = self.spatial_loss(raw_inp, rec_bchw, active)
        return active_ex, rec_bchw, spatial_loss
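    
    # `forward` returns `active_ex` (the (B, 1, H, W) pixel-level visibility map),
    # `rec_bchw` (the dense reconstruction), and the scalar reconstruction loss
    # computed by `spatial_loss` below.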
    
    def spatial_loss(self, inp, rec, active):  # active: (B, 1, f, f)
        mean = var = None
        if self.pix_norm == 2:
            mean, var = inp.mean(dim=(2, 3), keepdim=True), None
            rec = rec + mean
        inp = self.patchify(inp)
        rec = self.patchify(rec)
        # (B, L=fmap_size**2, N=downsample_raito**2 * C)
        if self.pix_norm == 1:
            mean = inp.mean(dim=-1, keepdim=True)
            var = (inp.var(dim=-1, keepdim=True) + 1e-6)**.5
            inp = (inp - mean) / var
        loss_spa = (rec-inp)**2 if self.loss_l2 else (rec-inp).abs()
        
        if self.dense_loss:
            return mean, var, loss_spa.mean()  # mean loss on all patches
        else:
            loss_spa = loss_spa.mean(dim=2, keepdim=False)  # (B, L, N) => (B, L)
            non_active = active.logical_not().int().view(active.shape[0], -1)  # (B, 1, f, f) => (B, L)
            return mean, var, loss_spa.mul_(non_active).sum() / (non_active.sum() + 1e-8)  # mean loss on removed patches
    
    def patchify(self, bchw):
        p = self.downsample_raito
        h = w = self.fmap_size
        B, C = bchw.shape[:2]
        bchw = bchw.reshape(shape=(B, C, h, p, w, p))
        bchw = torch.einsum('bchpwq->bhwpqc', bchw)
        bln = bchw.reshape(shape=(B, h*w, p**2 * C))  # (B, f*f, downsample_raito*downsample_raito*3)
        return bln
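    
    # `unpatchify` below is the exact inverse of `patchify`:
    # unpatchify(patchify(x)) reproduces x for any input with
    # H == W == fmap_size * downsample_raito.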
    
    def unpatchify(self, bln):
        p = self.downsample_raito
        h = w = self.fmap_size
        B, C = bln.shape[0], bln.shape[-1] // p**2
        bln = bln.reshape(shape=(B, h, w, p, p, C))
        bln = torch.einsum('bhwpqc->bchpwq', bln)
        bchw = bln.reshape(shape=(B, C, h * p, w * p))
        return bchw
    
    def denorm_for_vis(self, x, clamp):
        x = x * self.imn_s
        x += self.imn_m
        if clamp:
            x = torch.clamp(x, 0, 1)
        return x
    
    def __repr__(self):
        return (
            f'\n'
            f'[SparK.config]: {pformat(self.get_config(), indent=2, width=250)}\n'
            f'[SparK.structure]: {super(SparK, self).__repr__().replace(SparK.__name__, "")}\n'
            f'[SparK.dec]: {self.dense_decoder.num_para()}'
        )
    
    def get_config(self):
        return {
            # self
            'mask_ratio': self.mask_ratio[0], 'mask_ratio2': self.mask_ratio[1], 'uniform': self.uniform,
            'using_pe': self.using_pe, 'pix_norm': self.pix_norm,
            'dense_loss': self.dense_loss, 'loss_l2': self.loss_l2,
            'en_de_norm': self.en_de_norm_str, 'en_de_lin': self.en_de_lin_bool, 'sbn': self.sbn, 'pyramid': self.pyramid,
            
            # enc
            'input_size': self.sparse_encoder.input_size,
            # dec
            'dec_fea_dim': self.dense_decoder.fea_dim, 'double': self.dense_decoder.double_bool, 'heavy': self.dense_decoder.heavy,
        }
    
    def state_dict(self, destination=None, prefix='', keep_vars=False, with_config=False):
        state = super(SparK, self).state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        if with_config:
            state['config'] = self.get_config()  # todo: this seems to cause a DDP broadcast error??
        return state
    
    def load_state_dict(self, state_dict, strict=True):
        config = state_dict.pop('config', None)
        incompatible_keys = super(SparK, self).load_state_dict(state_dict, strict=strict)
        if config is not None:
            for k, v in self.get_config().items():
                if config.get(k, None) != v:
                    err = f'[SparK.load_state_dict] config mismatch: this.{k}={v} (ckpt.{k}={config.get(k, None)})'
                    if strict:
                        raise AttributeError(err)
                    else:
                        print(err)
        
        return incompatible_keys
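

# A minimal sketch of one pre-training step with SparK (hypothetical usage: the
# optimizer choice and the already-built `model` are assumptions, not part of this
# file; constructing `model` requires a real SparseEncoder and LightDecoder).
def _example_pretrain_step(model: SparK, opt: torch.optim.Optimizer, batch_bchw: torch.Tensor) -> float:
    active_ex, rec_bchw, loss = model(batch_bchw)  # mask, sparsely encode, densely decode
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()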


def _make_divisible(v, divisor=8, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by `divisor`.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # make sure that rounding down does not reduce the value by more than 10%
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
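

# For example: _make_divisible(30) == 32 and _make_divisible(20) == 24. The value is
# rounded to the nearest multiple of `divisor`, then bumped up one multiple if the
# rounding lost more than 10% of `v`.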


def get_2d_sincos_pos_embed(embed_dim, grid_size):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] (no cls_token is used here)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    
    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0
    
    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
    
    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)  # `np.float` is removed in NumPy >= 1.24
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)
    
    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
    
    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    
    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


if __name__ == '__main__':
    # `SparK.test_mask` and `SparK.test_align` are not defined in this file, so the
    # original calls here would raise AttributeError; they are replaced by the
    # self-contained checks below.
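    # A minimal smoke test (a sketch covering only the module-level helpers;
    # exercising SparK itself would require constructing a SparseEncoder and a
    # LightDecoder, which live in other files).
    pe = get_2d_sincos_pos_embed(embed_dim=64, grid_size=7)
    assert pe.shape == (7 * 7, 64), pe.shape
    assert _make_divisible(30) == 32 and _make_divisible(20) == 24
    print('helper smoke tests passed')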