# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from timm.models.layers import DropPath

# Binary mask of the currently active (unmasked) patches, shape (B, 1, fmap_h, fmap_w).
_cur_active: torch.Tensor = None  # B1ff

# todo: try to use `gather` for speed?
def _get_active_ex_or_ii(H, W, returning_active_ex=True):
    """Upsample the patch-level mask `_cur_active` to an (H, W) feature map.

    Returns the dense B1HW mask itself, or, when `returning_active_ex` is False,
    the (batch, row, col) indices of the unmasked positions.
    """
    h_repeat, w_repeat = H // _cur_active.shape[-2], W // _cur_active.shape[-1]
    active_ex = _cur_active.repeat_interleave(h_repeat, dim=2).repeat_interleave(w_repeat, dim=3)
    return (
        active_ex
        if returning_active_ex
        else active_ex.squeeze(1).nonzero(as_tuple=True)
    )  # ii: bi, hi, wi

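# Illustrative sketch (not part of the original module): how the patch-level mask
# is broadcast to feature-map resolution. With a 2x2 patch mask and an 8x8 feature
# map, each mask entry is repeated over a 4x4 block.
#
#     mask = torch.tensor([[[[1., 0.], [0., 1.]]]])  # B1ff with B=1, f=2
#     dense = mask.repeat_interleave(4, dim=2).repeat_interleave(4, dim=3)
#     assert dense.shape == (1, 1, 8, 8)
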
def sp_conv_forward(self, x: torch.Tensor):
    x = super(type(self), self).forward(x)
    x *= _get_active_ex_or_ii(
        H=x.shape[2], W=x.shape[3], returning_active_ex=True
    )  # (BCHW) *= (B1HW), mask the output of conv
    return x

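# Illustrative sketch (assumed usage, not from the original file): the mask is
# re-applied after every conv/pool/upsample because a convolution "leaks" non-zero
# values from unmasked pixels into neighbouring masked pixels.
#
#     y = conv(x)                                                # masked positions may become non-zero
#     y = y * _get_active_ex_or_ii(H=y.shape[2], W=y.shape[3])   # zero them out again
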
def sp_bn_forward(self, x: torch.Tensor):
    ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=False)

    bhwc = x.permute(0, 2, 3, 1)
    # select the features on non-masked positions to form a flattened feature `nc`
    nc = bhwc[ii]
    # use BN1d to normalize this flattened feature `nc`
    nc = super(type(self), self).forward(nc)

    # scatter the normalized features back to their positions; masked positions stay zero
    out = torch.zeros_like(bhwc)
    out[ii] = nc
    return out.permute(0, 3, 1, 2)  # back to BCHW

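# Illustrative sketch (not part of the original module): the gather -> BN1d ->
# scatter pattern above, with hypothetical shapes. Only the N unmasked pixels are
# normalized, so all-zero masked pixels never pollute the batch statistics.
#
#     bhwc = x.permute(0, 2, 3, 1)   # BCHW -> BHWC
#     nc = bhwc[ii]                  # (N, C): one row per unmasked pixel
#     nc = batchnorm1d(nc)           # statistics over unmasked pixels only
#     out = torch.zeros_like(bhwc)
#     out[ii] = nc                   # scatter back; masked pixels remain 0
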
class SparseUpsample(nn.Upsample):
    forward = sp_conv_forward


class SparseConv2d(nn.Conv2d):
    forward = sp_conv_forward  # hack: override the forward function; see `sp_conv_forward` above for more details


class SparseMaxPooling(nn.MaxPool2d):
    forward = sp_conv_forward  # hack: override the forward function; see `sp_conv_forward` above for more details


class SparseAvgPooling(nn.AvgPool2d):
    forward = sp_conv_forward  # hack: override the forward function; see `sp_conv_forward` above for more details


class SparseBatchNorm2d(nn.BatchNorm1d):
    forward = sp_bn_forward  # hack: override the forward function; see `sp_bn_forward` above for more details


class SparseSyncBatchNorm2d(nn.SyncBatchNorm):
    forward = sp_bn_forward  # hack: override the forward function; see `sp_bn_forward` above for more details

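# Illustrative sketch (assumed usage, not from the original file): drive a sparse
# layer by setting the module-level mask before the forward pass. Shapes are
# hypothetical (2 images, 7x7 patch grid, 56x56 feature map), and `encoder` is an
# assumed name for this module.
#
#     import encoder
#     encoder._cur_active = (torch.rand(2, 1, 7, 7) > 0.6).float()
#     conv = encoder.SparseConv2d(3, 16, kernel_size=3, padding=1)
#     y = conv(torch.randn(2, 3, 56, 56))   # masked 8x8 blocks of `y` are all zero
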
class SparseConvNeXtLayerNorm(nn.LayerNorm):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
    `data_format` specifies the ordering of the dimensions in the inputs: channels_last
    corresponds to inputs with shape (batch_size, height, width, channels), while
    channels_first corresponds to inputs with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", sparse=True):
        if data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        super().__init__(normalized_shape, eps, elementwise_affine=True)
        self.data_format = data_format
        self.sparse = sparse

    def forward(self, x):
        if x.ndim == 4:  # BHWC or BCHW
            if self.data_format == "channels_last":  # BHWC
                if self.sparse:
                    ii = _get_active_ex_or_ii(H=x.shape[1], W=x.shape[2], returning_active_ex=False)
                    nc = x[ii]
                    nc = super(SparseConvNeXtLayerNorm, self).forward(nc)

                    x = torch.zeros_like(x)
                    x[ii] = nc
                    return x
                else:
                    return super(SparseConvNeXtLayerNorm, self).forward(x)
            else:  # channels_first, BCHW
                if self.sparse:
                    ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=False)
                    bhwc = x.permute(0, 2, 3, 1)
                    nc = bhwc[ii]
                    nc = super(SparseConvNeXtLayerNorm, self).forward(nc)

                    x = torch.zeros_like(bhwc)
                    x[ii] = nc
                    return x.permute(0, 3, 1, 2)
                else:
                    u = x.mean(1, keepdim=True)
                    s = (x - u).pow(2).mean(1, keepdim=True)
                    x = (x - u) / torch.sqrt(s + self.eps)
                    x = self.weight[:, None, None] * x + self.bias[:, None, None]
                    return x
        else:  # BLC or BC
            if self.sparse:
                raise NotImplementedError
            else:
                return super(SparseConvNeXtLayerNorm, self).forward(x)

    def __repr__(self):
        return (
            super(SparseConvNeXtLayerNorm, self).__repr__()[:-1]
            + f', ch={self.data_format.split("_")[-1]}, sp={self.sparse})'
        )

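# Illustrative sketch (assumed usage, not from the original file): in channels_last
# mode with sparse=True, the LayerNorm is applied only to the unmasked BHWC rows,
# i.e. it behaves like an ordinary LayerNorm restricted to visible patches.
#
#     ln = SparseConvNeXtLayerNorm(16, data_format="channels_last", sparse=True)
#     x = torch.randn(2, 56, 56, 16)   # BHWC; assumes `_cur_active` is already set
#     y = ln(x)                        # masked positions of `y` are exactly zero
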
class SparseGlobalAveragePooling(nn.Module):
    def forward(self, x):  # shape: BCHW
        B, C, H, W = x.shape
        unmasked_positions = _get_active_ex_or_ii(H=H, W=W, returning_active_ex=True)  # shape: B1HW
        mean = (x * unmasked_positions).sum(dim=(2, 3), keepdim=True) / unmasked_positions.sum(dim=(2, 3), keepdim=True)
        return mean  # shape: BC11

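# Illustrative sketch (not part of the original module): the pooled value is the
# mean over unmasked pixels only. With hypothetical values on a 1x1x1x4 map where
# two pixels are masked:
#
#     x    = [4., 0., 6., 0.]    # zeros at masked positions
#     mask = [1., 0., 1., 0.]
#     mean = (4 + 6) / 2 = 5     # not (4 + 6) / 4, which a dense average would give
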
class SparseConvNeXtBlock(nn.Module):
    r"""ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """

    def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6, sparse=True, ks=7):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=ks, padding=ks // 2, groups=dim)  # depthwise conv
        self.norm = SparseConvNeXtLayerNorm(dim, eps=1e-6, sparse=sparse)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path: nn.Module = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.sparse = sparse

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)  # GELU(0) == 0, so there is no need to mask x here (no need to `x *= _get_active_ex_or_ii`)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        if self.sparse:
            x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=True)

        x = input + self.drop_path(x)
        return x

    def __repr__(self):
        return super(SparseConvNeXtBlock, self).__repr__()[:-1] + f", sp={self.sparse})"

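# Illustrative sketch (assumed usage, not from the original file): running one
# sparse ConvNeXt block. The residual connection adds the input's zeros at masked
# positions to a re-masked branch output, so sparsity is preserved block to block.
#
#     blk = SparseConvNeXtBlock(dim=16, sparse=True)
#     x = torch.randn(2, 16, 56, 56)
#     x = x * _get_active_ex_or_ii(H=56, W=56)   # zero out masked positions first
#     y = blk(x)                                 # masked positions of `y` stay zero
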
class SparseEncoder(nn.Module):
    def __init__(self, cnn, input_size, sbn=False, verbose=False):
        super(SparseEncoder, self).__init__()
        self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=cnn, verbose=verbose, sbn=sbn)
        self.input_size, self.downsample_raito, self.enc_feat_map_chs = (
            input_size,
            cnn.get_downsample_ratio(),
            cnn.get_feature_map_channels(),
        )

    @staticmethod
    def dense_model_to_sparse(m: nn.Module, verbose=False, sbn=False):
        oup = m
        if isinstance(m, nn.Conv2d):
            m: nn.Conv2d
            bias = m.bias is not None
            oup = SparseConv2d(
                m.in_channels,
                m.out_channels,
                kernel_size=m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                groups=m.groups,
                bias=bias,
                padding_mode=m.padding_mode,
            )
            oup.weight.data.copy_(m.weight.data)
            if bias:
                oup.bias.data.copy_(m.bias.data)
        elif isinstance(m, nn.Upsample):
            oup = SparseUpsample(
                m.size,
                m.scale_factor,
                m.mode,
                m.align_corners,
                m.recompute_scale_factor,
            )
        elif isinstance(m, nn.MaxPool2d):
            m: nn.MaxPool2d
            oup = SparseMaxPooling(
                m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                return_indices=m.return_indices,
                ceil_mode=m.ceil_mode,
            )
        elif isinstance(m, nn.AvgPool2d):
            m: nn.AvgPool2d
            oup = SparseAvgPooling(
                m.kernel_size,
                m.stride,
                m.padding,
                ceil_mode=m.ceil_mode,
                count_include_pad=m.count_include_pad,
                divisor_override=m.divisor_override,
            )
        elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            m: nn.BatchNorm2d
            oup = (SparseSyncBatchNorm2d if sbn else SparseBatchNorm2d)(
                m.weight.shape[0],
                eps=m.eps,
                momentum=m.momentum,
                affine=m.affine,
                track_running_stats=m.track_running_stats,
            )
            oup.weight.data.copy_(m.weight.data)
            oup.bias.data.copy_(m.bias.data)
            oup.running_mean.data.copy_(m.running_mean.data)
            oup.running_var.data.copy_(m.running_var.data)
            oup.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
            if hasattr(m, "qconfig"):
                oup.qconfig = m.qconfig
        elif isinstance(m, nn.LayerNorm) and not isinstance(m, SparseConvNeXtLayerNorm):
            m: nn.LayerNorm
            oup = SparseConvNeXtLayerNorm(m.weight.shape[0], eps=m.eps)
            oup.weight.data.copy_(m.weight.data)
            oup.bias.data.copy_(m.bias.data)
        elif isinstance(m, (nn.Conv1d,)):
            raise NotImplementedError

        if oup is not m:
            # copy over any extra (non-dunder, non-callable) attributes that the dense
            # module carries but the sparse replacement does not already have
            oup_member = set(
                x for x in dir(oup)
                if not x.startswith("__") and not callable(getattr(oup, x, None))
            )
            m_member = set(
                x for x in dir(m)
                if not x.startswith("__") and not callable(getattr(m, x, None))
            )
            for x in m_member - oup_member:
                setattr(oup, x, getattr(m, x))

        for name, child in m.named_children():
            oup.add_module(
                name,
                SparseEncoder.dense_model_to_sparse(child, verbose=verbose, sbn=sbn),
            )
        del m
        return oup

    def forward(self, x):
        return self.sp_cnn(x, hierarchical=True)

    def __repr__(self):
        return self.sp_cnn.__repr__()
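

# Illustrative sketch (assumed usage, not from the original file): converting a tiny
# dense CNN to its sparse counterpart and running it under a patch mask. A backbone
# passed to `SparseEncoder` itself is expected to provide `get_downsample_ratio()`,
# `get_feature_map_channels()` and a `forward(..., hierarchical=True)` signature, so
# this snippet calls `dense_model_to_sparse` directly on a plain `nn.Sequential`.
#
#     dense = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
#     sparse = SparseEncoder.dense_model_to_sparse(dense)
#     # Conv2d -> SparseConv2d, BatchNorm2d -> SparseBatchNorm2d, ReLU left unchanged
#     _cur_active = (torch.rand(2, 1, 7, 7) > 0.6).float()   # module-level mask, B1ff
#     y = sparse(torch.randn(2, 3, 56, 56))                  # masked regions of `y` are zero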