|
|
|
# Copyright (c) ByteDance, Inc. and its affiliates.
|
|
|
|
# All rights reserved.
|
|
|
|
#
|
|
|
|
# This source code is licensed under the license found in the
|
|
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from timm.models.layers import DropPath
|
|
|
|
|
|
|
|
|
|
|
|
# Global mask of the currently active (visible) positions, shared by every sparse layer in this
# file through `_get_active_ex_or_ii`.
# That helper treats it as a rank-4 tensor laid out as (B, 1, f, f) ("B1ff"): it repeats dims 2
# and 3 up to the feature-map size and squeezes dim 1.
# It is None until something assigns it.
# NOTE(review): the assignment is not visible here; presumably the pre-training model sets it
# before each forward pass. Confirm with the caller.
_cur_active: torch.Tensor = None  # B1ff
|
|
|
|
# TODO: try to use `gather` for speed?
def _get_active_ex_or_ii(H, returning_active_ex=True):
    """Expand the global active mask to spatial size ``H``, or return its active indices.

    The mask ``_cur_active`` is repeated ``H // f`` times along dims 2 and 3, where ``f`` is its
    last dimension. The same ratio is used for both dims.
    NOTE(review): this implicitly assumes a square mask and square feature maps.

    Args:
        H: target spatial size (size of dim 2 of the feature map being masked).
        returning_active_ex: if True, return the expanded mask. Otherwise return the tuple of
            index tensors ``(bi, hi, wi)`` where the expanded mask, with dim 1 squeezed, is nonzero.
    """
    ratio = H // _cur_active.shape[-1]
    expanded = _cur_active
    for dim in (2, 3):
        expanded = expanded.repeat_interleave(ratio, dim)
    if returning_active_ex:
        return expanded
    return expanded.squeeze(1).nonzero(as_tuple=True)  # ii: bi, hi, wi
|
|
|
|
|
|
|
|
|
|
|
|
def sp_conv_forward(self, x: torch.Tensor):
    """Forward of the parent layer, then zero out every inactive output position.

    This is bound as ``forward`` on the sparse conv/pooling subclasses. The parent class's
    forward is reached through ``super(type(self), self)``.
    The output (BCHW) is multiplied in place by the active mask (B1HW) expanded to the
    output's spatial size.
    """
    parent_forward = super(type(self), self).forward
    out = parent_forward(x)
    mask = _get_active_ex_or_ii(H=out.shape[2], returning_active_ex=True)
    out *= mask  # (BCHW) *= (B1HW): masked positions become 0
    return out
|
|
|
|
|
|
|
|
|
|
|
|
def sp_bn_forward(self, x: torch.Tensor):
    """Batch-normalize only the active positions of a BCHW tensor.

    This is bound as ``forward`` on the sparse BN subclasses. The active positions are gathered
    into a flat ``(num_active, C)`` tensor and normalized by the parent (1-D) BN class's forward.
    The results are scattered back into a zero-filled tensor, so inactive positions come out as 0.
    Returns a BCHW tensor (a permuted view of a BHWC buffer).
    """
    active_idx = _get_active_ex_or_ii(H=x.shape[2], returning_active_ex=False)

    channels_last = x.permute(0, 2, 3, 1)  # BCHW -> BHWC
    flat_feats = super(type(self), self).forward(channels_last[active_idx])  # (num_active, C)

    scattered = torch.zeros_like(channels_last)
    scattered[active_idx] = flat_feats
    return scattered.permute(0, 3, 1, 2)  # BHWC -> BCHW
|
|
|
|
|
|
|
|
|
|
|
|
class SparseConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose output is multiplied by the current active mask.

    The mask is upsampled to the output's spatial size, so inactive positions are zeroed.
    NOTE(review): the output size presumably must be a multiple of the mask size for the
    expanded mask to broadcast.
    """
    forward = sp_conv_forward  # hack: override the forward function; see `sp_conv_forward` above for more details
|
|
|
|
|
|
|
|
|
|
|
|
class SparseMaxPooling(nn.MaxPool2d):
    """``nn.MaxPool2d`` whose output is multiplied by the current active mask.

    NOTE(review): with ``return_indices=True`` the parent forward returns a tuple. The overridden
    forward reads ``.shape`` on that result and multiplies it in place, so that mode is
    presumably unsupported.
    """
    forward = sp_conv_forward  # hack: override the forward function; see `sp_conv_forward` above for more details
|
|
|
|
|
|
|
|
|
|
|
|
class SparseAvgPooling(nn.AvgPool2d):
    """``nn.AvgPool2d`` whose output is multiplied by the current active mask."""
    forward = sp_conv_forward  # hack: override the forward function; see `sp_conv_forward` above for more details
|
|
|
|
|
|
|
|
|
|
|
|
class SparseBatchNorm2d(nn.BatchNorm1d):
    """BatchNorm over only the active positions of a BCHW tensor.

    It subclasses ``BatchNorm1d`` (not 2d) on purpose: the overridden forward feeds the parent
    a flattened ``(num_active, C)`` tensor.
    """
    forward = sp_bn_forward  # hack: override the forward function; see `sp_bn_forward` above for more details
|
|
|
|
|
|
|
|
|
|
|
|
class SparseSyncBatchNorm2d(nn.SyncBatchNorm):
    """Synchronized BatchNorm over only the active positions of a BCHW tensor.

    Like ``SparseBatchNorm2d``, the overridden forward hands the parent a flattened
    ``(num_active, C)`` tensor.
    """
    forward = sp_bn_forward  # hack: override the forward function; see `sp_bn_forward` above for more details
|
|
|
|
|
|
|
|
|
|
|
|
class SparseConvNeXtLayerNorm(nn.LayerNorm):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.

    ``data_format`` gives the ordering of the dimensions in 4-D inputs. channels_last
    corresponds to inputs of shape (batch_size, height, width, channels), while
    channels_first corresponds to inputs of shape (batch_size, channels, height, width).

    When ``sparse`` is True and the input is 4-D, only the positions marked active by the
    module-level mask are normalized, and every other position is zero in the output.
    Inputs that are not 4-D (e.g. BLC or BC) are only supported with ``sparse=False``.

    Args:
        normalized_shape: number of channels C (size of the normalized channel dimension).
        eps: value added to the variance for numerical stability. Default: 1e-6.
        data_format: "channels_last" (BHWC) or "channels_first" (BCHW). Default: "channels_last".
        sparse: normalize only the active positions of 4-D inputs. Default: True.

    Raises:
        NotImplementedError: for an unknown ``data_format``, or for a non-4-D input when
            ``sparse`` is True.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", sparse=True):
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError(
                f'data_format must be "channels_last" or "channels_first", got {data_format!r}'
            )
        super().__init__(normalized_shape, eps, elementwise_affine=True)
        self.data_format = data_format
        self.sparse = sparse

    def _sparse_forward_bhwc(self, bhwc: torch.Tensor) -> torch.Tensor:
        """Normalize the active positions of a BHWC tensor over C; zero every other position."""
        ii = _get_active_ex_or_ii(H=bhwc.shape[1], returning_active_ex=False)
        nc = super().forward(bhwc[ii])  # gather active positions into a flat (num_active, C) tensor
        out = torch.zeros_like(bhwc)
        out[ii] = nc  # scatter the normalized features back; inactive positions stay 0
        return out

    def forward(self, x):
        if x.ndim != 4:  # BLC or BC
            if self.sparse:
                raise NotImplementedError(
                    f"sparse {type(self).__name__} only supports 4-D inputs, got a {x.ndim}-D input; "
                    "use sparse=False"
                )
            return super().forward(x)

        if self.data_format == "channels_last":  # BHWC
            if self.sparse:
                return self._sparse_forward_bhwc(x)
            return super().forward(x)

        # channels_first, BCHW
        if self.sparse:
            return self._sparse_forward_bhwc(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        # Dense LayerNorm over the channel dim (dim 1), using the biased variance like F.layer_norm.
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]

    def __repr__(self):
        return super(SparseConvNeXtLayerNorm, self).__repr__()[:-1] + f', ch={self.data_format.split("_")[-1]}, sp={self.sparse})'
|
|
|
|
|
|
|
|
|
|
|
|
class SparseConvNeXtBlock(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Layer Scale is disabled
            when this is <= 0. Default: 1e-6.
        sparse (bool): If True, the LayerNorm only normalizes active positions, and the residual
            branch is multiplied by the current active mask before the residual add. Default: True.
        ks (int): Kernel size of the depthwise convolution (padding is ks // 2). Default: 7.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, sparse=True, ks=7):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=ks, padding=ks // 2, groups=dim)  # depthwise conv
        self.norm = SparseConvNeXtLayerNorm(dim, eps=1e-6, sparse=sparse)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # Layer Scale: a learnable per-channel multiplier on the residual branch.
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0 else None
        )
        self.drop_path: nn.Module = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.sparse = sparse

    def forward(self, x):
        shortcut = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        # NOTE: a sparse norm outputs 0 at masked positions, but pwconv1 adds its bias there, so
        # they are no longer 0 at this point. They are re-zeroed by the mask applied below.
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        if self.sparse:
            x *= _get_active_ex_or_ii(H=x.shape[2], returning_active_ex=True)

        return shortcut + self.drop_path(x)

    def __repr__(self):
        return super(SparseConvNeXtBlock, self).__repr__()[:-1] + f', sp={self.sparse})'
|
|
|
|
|
|
|
|
|
|
|
|
class SparseEncoder(nn.Module):
    """Wrap a dense CNN after converting its supported layers to sparse counterparts.

    Args:
        conv_model: the dense model to convert with ``dense_model_to_sparse``. NOTE: unless the
            root module itself is replaced, the conversion happens in place (children are swapped
            via ``add_module``), so ``conv_model`` is modified.
        input_size: stored as ``self.input_size``.
        downsample_raito: stored as ``self.downsample_raito`` (spelling kept for backward compatibility).
        encoder_fea_dim: stored as ``self.fea_dim``.
        sbn: if True, BatchNorm layers become ``SparseSyncBatchNorm2d``; otherwise ``SparseBatchNorm2d``.
        verbose: forwarded to ``dense_model_to_sparse`` (currently unused there).
    """

    def __init__(self, conv_model, input_size, downsample_raito, encoder_fea_dim, sbn=False, verbose=False):
        super(SparseEncoder, self).__init__()
        self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=conv_model, verbose=verbose, sbn=sbn)
        self.input_size, self.downsample_raito, self.fea_dim = input_size, downsample_raito, encoder_fea_dim

    @staticmethod
    def dense_model_to_sparse(m: nn.Module, verbose=False, sbn=False):
        """Recursively replace the dense layers of ``m`` with their sparse counterparts.

        Replacement rules:
        - Conv2d, MaxPool2d, AvgPool2d, BatchNorm2d / SyncBatchNorm, and plain LayerNorm modules
          are replaced by new sparse modules with the same hyper-parameters.
        - Their parameters and buffers, when present, are copied over.
        - Any other module is kept as-is.

        Children are then converted recursively and re-registered on the returned module.

        Returns:
            The converted module: a new object if ``m`` itself was replaced, otherwise ``m``.

        Raises:
            NotImplementedError: if an ``nn.Conv1d`` is encountered.
        """
        oup = m
        if isinstance(m, nn.Conv2d):
            has_bias = m.bias is not None
            oup = SparseConv2d(
                m.in_channels, m.out_channels,
                kernel_size=m.kernel_size, stride=m.stride, padding=m.padding,
                dilation=m.dilation, groups=m.groups, bias=has_bias, padding_mode=m.padding_mode,
            )
            oup.weight.data.copy_(m.weight.data)
            if has_bias:
                oup.bias.data.copy_(m.bias.data)

        elif isinstance(m, nn.MaxPool2d):
            oup = SparseMaxPooling(m.kernel_size, stride=m.stride, padding=m.padding, dilation=m.dilation, return_indices=m.return_indices, ceil_mode=m.ceil_mode)

        elif isinstance(m, nn.AvgPool2d):
            oup = SparseAvgPooling(m.kernel_size, m.stride, m.padding, ceil_mode=m.ceil_mode, count_include_pad=m.count_include_pad, divisor_override=m.divisor_override)

        elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            bn_cls = SparseSyncBatchNorm2d if sbn else SparseBatchNorm2d
            # Use num_features rather than m.weight.shape[0]: m.weight is None when affine=False.
            oup = bn_cls(m.num_features, eps=m.eps, momentum=m.momentum, affine=m.affine, track_running_stats=m.track_running_stats)
            if m.affine:
                oup.weight.data.copy_(m.weight.data)
                oup.bias.data.copy_(m.bias.data)
            # Running statistics are None when the BN was built with track_running_stats=False.
            if m.track_running_stats and m.running_mean is not None:
                oup.running_mean.data.copy_(m.running_mean.data)
                oup.running_var.data.copy_(m.running_var.data)
                oup.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
            if hasattr(m, "qconfig"):
                oup.qconfig = m.qconfig

        elif isinstance(m, nn.LayerNorm) and not isinstance(m, SparseConvNeXtLayerNorm):
            # normalized_shape[0] equals weight.shape[0], but it is also available when
            # elementwise_affine=False (weight/bias are None then).
            oup = SparseConvNeXtLayerNorm(m.normalized_shape[0], eps=m.eps)
            if m.weight is not None:
                oup.weight.data.copy_(m.weight.data)
            if m.bias is not None:
                oup.bias.data.copy_(m.bias.data)

        elif isinstance(m, nn.Conv1d):
            raise NotImplementedError("nn.Conv1d is not supported by the sparse conversion")

        for name, child in m.named_children():
            oup.add_module(name, SparseEncoder.dense_model_to_sparse(child, verbose=verbose, sbn=sbn))
        return oup

    def forward(self, x, hierarchy):
        # The wrapped model must accept a `hierarchy` keyword argument.
        return self.sp_cnn(x, hierarchy=hierarchy)
|