# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
import math
from functools import partial

import paddle as pd
import paddle.nn as nn
import paddle.nn.functional as F

from .layers.pd_timm import DropPath, to_2tuple


def calc_product(*args):
    if len(args) < 1:
        raise ValueError("calc_product expects at least one argument.")
    ret = args[0]
    for arg in args[1:]:
        ret *= arg
    return ret


class ConvBlock(pd.nn.Layer):
    def __init__(self,
                 input_size,
                 output_size,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=True,
                 activation='prelu',
                 norm=None):
        super(ConvBlock, self).__init__()
        self.conv = pd.nn.Conv2D(
            input_size,
            output_size,
            kernel_size,
            stride,
            padding,
            bias_attr=bias)

        self.norm = norm
        if self.norm == 'batch':
            self.bn = pd.nn.BatchNorm2D(output_size)
        elif self.norm == 'instance':
            self.bn = pd.nn.InstanceNorm2D(output_size)

        self.activation = activation
        if self.activation == 'relu':
            self.act = pd.nn.ReLU()
        elif self.activation == 'prelu':
            self.act = pd.nn.PReLU()
        elif self.activation == 'lrelu':
            self.act = pd.nn.LeakyReLU(0.2)
        elif self.activation == 'tanh':
            self.act = pd.nn.Tanh()
        elif self.activation == 'sigmoid':
            self.act = pd.nn.Sigmoid()

    def forward(self, x):
        if self.norm is not None:
            out = self.bn(self.conv(x))
        else:
            out = self.conv(x)

        if self.activation != 'no':
            return self.act(out)
        else:
            return out


class DeconvBlock(pd.nn.Layer):
    def __init__(self,
                 input_size,
                 output_size,
                 kernel_size=4,
                 stride=2,
                 padding=1,
                 bias=True,
                 activation='prelu',
                 norm=None):
        super(DeconvBlock, self).__init__()
        self.deconv = pd.nn.Conv2DTranspose(
            input_size,
            output_size,
            kernel_size,
            stride,
            padding,
            bias_attr=bias)

        self.norm = norm
        if self.norm == 'batch':
            self.bn = pd.nn.BatchNorm2D(output_size)
        elif self.norm == 'instance':
            self.bn = pd.nn.InstanceNorm2D(output_size)

        self.activation = activation
        if self.activation == 'relu':
            self.act = pd.nn.ReLU()
        elif self.activation == 'prelu':
            self.act = pd.nn.PReLU()
        elif self.activation == 'lrelu':
            self.act = pd.nn.LeakyReLU(0.2)
        elif self.activation == 'tanh':
            self.act = pd.nn.Tanh()
        elif self.activation == 'sigmoid':
            self.act = pd.nn.Sigmoid()

    def forward(self, x):
        if self.norm is not None:
            out = self.bn(self.deconv(x))
        else:
            out = self.deconv(x)

        if self.activation is not None:
            return self.act(out)
        else:
            return out


class ConvLayer(nn.Layer):
    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding):
        super(ConvLayer, self).__init__()
        self.conv2d = nn.Conv2D(in_channels, out_channels, kernel_size, stride,
                                padding)

    def forward(self, x):
        out = self.conv2d(x)
        return out


class UpsampleConvLayer(pd.nn.Layer):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(UpsampleConvLayer, self).__init__()
        self.conv2d = nn.Conv2DTranspose(
            in_channels, out_channels, kernel_size, stride=stride, padding=1)

    def forward(self, x):
        out = self.conv2d(x)
        return out


class ResidualBlock(pd.nn.Layer):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvLayer(
            channels, channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = ConvLayer(
            channels, channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        out = self.relu(self.conv1(x))
        # Residual scaling: the branch output is multiplied by 0.1 before
        # being added back to the skip connection.
        out = self.conv2(out) * 0.1
        out = pd.add(out, residual)
        return out
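
# The sketch below is illustrative only and not part of the original
# network definition: with kernel_size=4 and stride=2 (the settings used
# by the decoder further down), UpsampleConvLayer doubles the spatial
# resolution of its input.
def _demo_upsample_conv_layer():
    up = UpsampleConvLayer(8, 8, kernel_size=4, stride=2)
    x = pd.rand([1, 8, 16, 16])
    y = up(x)  # [1, 8, 32, 32]: (16 - 1) * 2 - 2 * 1 + 4 = 32
    return y
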
class ChangeFormer(nn.Layer):
    """
    The ChangeFormer implementation based on PaddlePaddle.

    The original article refers to
        Wele Gedara Chaminda Bandara, Vishal M. Patel, "A Transformer-Based
        Siamese Network for Change Detection"
        (https://arxiv.org/pdf/2201.01293.pdf).

    Args:
        in_channels (int): Number of bands of the input images. Default: 3.
        num_classes (int): Number of target classes. Default: 2.
        decoder_softmax (bool, optional): Whether to apply an activation
            (a sigmoid in this implementation) to the decoder outputs.
            Default: False.
        embed_dim (int, optional): Embedding dimension of each decoder head.
            Default: 256.
    """

    def __init__(self,
                 in_channels=3,
                 num_classes=2,
                 decoder_softmax=False,
                 embed_dim=256):
        super(ChangeFormer, self).__init__()

        # Transformer encoder
        self.embed_dims = [64, 128, 320, 512]
        self.depths = [3, 3, 4, 3]
        self.embedding_dim = embed_dim
        self.drop_rate = 0.1
        self.attn_drop = 0.1
        self.drop_path_rate = 0.1

        self.Tenc_x2 = EncoderTransformer_v3(
            img_size=256,
            patch_size=7,
            in_chans=in_channels,
            num_classes=num_classes,
            embed_dims=self.embed_dims,
            num_heads=[1, 2, 4, 8],
            mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True,
            qk_scale=None,
            drop_rate=self.drop_rate,
            attn_drop_rate=self.attn_drop,
            drop_path_rate=self.drop_path_rate,
            norm_layer=partial(
                nn.LayerNorm, epsilon=1e-6),
            depths=self.depths,
            sr_ratios=[8, 4, 2, 1])

        # Transformer decoder
        self.TDec_x2 = DecoderTransformer_v3(
            input_transform='multiple_select',
            in_index=[0, 1, 2, 3],
            align_corners=False,
            in_channels=self.embed_dims,
            embedding_dim=self.embedding_dim,
            output_nc=num_classes,
            decoder_softmax=decoder_softmax,
            feature_strides=[2, 4, 8, 16])

    def forward(self, x1, x2):
        # Run the shared (Siamese) encoder on both temporal images.
        [fx1, fx2] = [self.Tenc_x2(x1), self.Tenc_x2(x2)]
        cp = self.TDec_x2(fx1, fx2)
        return [cp]
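
# A minimal usage sketch (illustrative, not part of the original file):
# build the detector with its defaults and run a forward pass on a random
# bi-temporal pair. The 256x256 input size matches the img_size wired into
# the encoder above.
def _demo_changeformer():
    model = ChangeFormer(in_channels=3, num_classes=2)
    t1 = pd.rand([1, 3, 256, 256])  # image at time 1
    t2 = pd.rand([1, 3, 256, 256])  # image at time 2
    logits = model(t1, t2)[0]  # forward() returns a single-element list
    return logits  # change-map logits of shape [1, 2, 256, 256]
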
# Transformer encoder producing features at x1/4, x1/8, x1/16, and x1/32
# scales.
class EncoderTransformer_v3(nn.Layer):
    def __init__(self,
                 img_size=256,
                 patch_size=3,
                 in_chans=3,
                 num_classes=2,
                 embed_dims=[32, 64, 128, 256],
                 num_heads=[2, 2, 4, 8],
                 mlp_ratios=[4, 4, 4, 4],
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 depths=[3, 3, 6, 18],
                 sr_ratios=[8, 4, 2, 1]):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.embed_dims = embed_dims

        # Patch embedding definitions
        self.patch_embed1 = OverlapPatchEmbed(
            img_size=img_size,
            patch_size=7,
            stride=4,
            in_chans=in_chans,
            embed_dim=embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(
            img_size=img_size // 4,
            patch_size=patch_size,
            stride=2,
            in_chans=embed_dims[0],
            embed_dim=embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(
            img_size=img_size // 8,
            patch_size=patch_size,
            stride=2,
            in_chans=embed_dims[1],
            embed_dim=embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(
            img_size=img_size // 16,
            patch_size=patch_size,
            stride=2,
            in_chans=embed_dims[2],
            embed_dim=embed_dims[3])

        # Stochastic depth decay rule: drop-path probabilities increase
        # linearly over all blocks.
        dpr = [x.item() for x in pd.linspace(0, drop_path_rate, sum(depths))]

        # Stage-1 (x1/4 scale)
        cur = 0
        self.block1 = nn.LayerList([
            Block(
                dim=embed_dims[0],
                num_heads=num_heads[0],
                mlp_ratio=mlp_ratios[0],
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[cur + i],
                norm_layer=norm_layer,
                sr_ratio=sr_ratios[0]) for i in range(depths[0])
        ])
        self.norm1 = norm_layer(embed_dims[0])

        # Stage-2 (x1/8 scale)
        cur += depths[0]
        self.block2 = nn.LayerList([
            Block(
                dim=embed_dims[1],
                num_heads=num_heads[1],
                mlp_ratio=mlp_ratios[1],
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[cur + i],
                norm_layer=norm_layer,
                sr_ratio=sr_ratios[1]) for i in range(depths[1])
        ])
        self.norm2 = norm_layer(embed_dims[1])

        # Stage-3 (x1/16 scale)
        cur += depths[1]
        self.block3 = nn.LayerList([
            Block(
                dim=embed_dims[2],
                num_heads=num_heads[2],
                mlp_ratio=mlp_ratios[2],
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[cur + i],
                norm_layer=norm_layer,
                sr_ratio=sr_ratios[2]) for i in range(depths[2])
        ])
        self.norm3 = norm_layer(embed_dims[2])

        # Stage-4 (x1/32 scale)
        cur += depths[2]
        self.block4 = nn.LayerList([
            Block(
                dim=embed_dims[3],
                num_heads=num_heads[3],
                mlp_ratio=mlp_ratios[3],
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[cur + i],
                norm_layer=norm_layer,
                sr_ratio=sr_ratios[3]) for i in range(depths[3])
        ])
        self.norm4 = norm_layer(embed_dims[3])

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
            trunc_normal_op(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                init_bias = nn.initializer.Constant(0)
                init_bias(m.bias)
        elif isinstance(m, nn.LayerNorm):
            init_bias = nn.initializer.Constant(0)
            init_bias(m.bias)
            init_weight = nn.initializer.Constant(1.0)
            init_weight(m.weight)
        elif isinstance(m, nn.Conv2D):
            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
            fan_out //= m._groups
            init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
            init_weight(m.weight)
            if m.bias is not None:
                init_bias = nn.initializer.Constant(0)
                init_bias(m.bias)

    def reset_drop_path(self, drop_path_rate):
        dpr = [
            x.item() for x in pd.linspace(0, drop_path_rate, sum(self.depths))
        ]
        cur = 0
        for i in range(self.depths[0]):
            self.block1[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[0]
        for i in range(self.depths[1]):
            self.block2[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[1]
        for i in range(self.depths[2]):
            self.block3[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[2]
        for i in range(self.depths[3]):
            self.block4[i].drop_path.drop_prob = dpr[cur + i]

    def forward_features(self, x):
        B = x.shape[0]
        outs = []

        # Stage 1
        x1, H1, W1 = self.patch_embed1(x)
        for i, blk in enumerate(self.block1):
            x1 = blk(x1, H1, W1)
        x1 = self.norm1(x1)
        # Restore the [B, C, H, W] layout from the token sequence.
        x1 = x1.reshape(
            [B, H1, W1, calc_product(*x1.shape[1:]) // (H1 * W1)]).transpose(
                [0, 3, 1, 2])
        outs.append(x1)

        # Stage 2
        x1, H1, W1 = self.patch_embed2(x1)
        for i, blk in enumerate(self.block2):
            x1 = blk(x1, H1, W1)
        x1 = self.norm2(x1)
        x1 = x1.reshape(
            [B, H1, W1, calc_product(*x1.shape[1:]) // (H1 * W1)]).transpose(
                [0, 3, 1, 2])
        outs.append(x1)

        # Stage 3
        x1, H1, W1 = self.patch_embed3(x1)
        for i, blk in enumerate(self.block3):
            x1 = blk(x1, H1, W1)
        x1 = self.norm3(x1)
        x1 = x1.reshape(
            [B, H1, W1, calc_product(*x1.shape[1:]) // (H1 * W1)]).transpose(
                [0, 3, 1, 2])
        outs.append(x1)

        # Stage 4
        x1, H1, W1 = self.patch_embed4(x1)
        for i, blk in enumerate(self.block4):
            x1 = blk(x1, H1, W1)
        x1 = self.norm4(x1)
        x1 = x1.reshape(
            [B, H1, W1, calc_product(*x1.shape[1:]) // (H1 * W1)]).transpose(
                [0, 3, 1, 2])
        outs.append(x1)

        return outs

    def forward(self, x):
        x = self.forward_features(x)
        return x
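
# Illustrative sketch (not part of the original file): the encoder returns
# a four-level feature pyramid at 1/4, 1/8, 1/16, and 1/32 of the input
# resolution, with the channel widths given by `embed_dims`. The shallow
# depths here are chosen only to keep the demo fast.
def _demo_encoder_pyramid():
    enc = EncoderTransformer_v3(
        img_size=256,
        in_chans=3,
        embed_dims=[32, 64, 128, 256],
        num_heads=[2, 2, 4, 8],
        depths=[2, 2, 2, 2])
    feats = enc(pd.rand([1, 3, 256, 256]))
    # Shapes: [1, 32, 64, 64], [1, 64, 32, 32],
    #         [1, 128, 16, 16], [1, 256, 8, 8]
    return feats
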
class DecoderTransformer_v3(nn.Layer):
    """
    Transformer decoder.
    """

    def __init__(self,
                 input_transform='multiple_select',
                 in_index=[0, 1, 2, 3],
                 align_corners=True,
                 in_channels=[32, 64, 128, 256],
                 embedding_dim=64,
                 output_nc=2,
                 decoder_softmax=False,
                 feature_strides=[2, 4, 8, 16]):
        super(DecoderTransformer_v3, self).__init__()

        assert len(feature_strides) == len(in_channels)
        assert min(feature_strides) == feature_strides[0]

        # Settings
        self.feature_strides = feature_strides
        self.input_transform = input_transform
        self.in_index = in_index
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.embedding_dim = embedding_dim
        self.output_nc = output_nc
        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = \
            self.in_channels

        # MLP decoder heads
        self.linear_c4 = MLP(input_dim=c4_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels,
                             embed_dim=self.embedding_dim)

        # Convolutional difference layers
        self.diff_c4 = conv_diff(
            in_channels=2 * self.embedding_dim,
            out_channels=self.embedding_dim)
        self.diff_c3 = conv_diff(
            in_channels=2 * self.embedding_dim,
            out_channels=self.embedding_dim)
        self.diff_c2 = conv_diff(
            in_channels=2 * self.embedding_dim,
            out_channels=self.embedding_dim)
        self.diff_c1 = conv_diff(
            in_channels=2 * self.embedding_dim,
            out_channels=self.embedding_dim)

        # Intermediate prediction heads on the multi-scale difference
        # features.
        self.make_pred_c4 = make_prediction(
            in_channels=self.embedding_dim, out_channels=self.output_nc)
        self.make_pred_c3 = make_prediction(
            in_channels=self.embedding_dim, out_channels=self.output_nc)
        self.make_pred_c2 = make_prediction(
            in_channels=self.embedding_dim, out_channels=self.output_nc)
        self.make_pred_c1 = make_prediction(
            in_channels=self.embedding_dim, out_channels=self.output_nc)

        # Final linear fusion layer
        self.linear_fuse = nn.Sequential(
            nn.Conv2D(
                in_channels=self.embedding_dim * len(in_channels),
                out_channels=self.embedding_dim,
                kernel_size=1),
            nn.BatchNorm2D(self.embedding_dim))

        # Final prediction head
        self.convd2x = UpsampleConvLayer(
            self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
        self.dense_2x = nn.Sequential(ResidualBlock(self.embedding_dim))
        self.convd1x = UpsampleConvLayer(
            self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
        self.dense_1x = nn.Sequential(ResidualBlock(self.embedding_dim))
        self.change_probability = ConvLayer(
            self.embedding_dim,
            self.output_nc,
            kernel_size=3,
            stride=1,
            padding=1)

        # Final activation (note: despite the name, a sigmoid is applied).
        self.output_softmax = decoder_softmax
        self.active = nn.Sigmoid()

    def _transform_inputs(self, inputs):
        """
        Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level image features.

        Returns:
            Tensor: The transformed inputs.
        """
        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = pd.concat(upsampled_inputs, axis=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]
        return inputs

    def forward(self, inputs1, inputs2):
        # Transform encoder features (select layers).
        x_1 = self._transform_inputs(inputs1)  # len=4: 1/4, 1/8, 1/16, 1/32
        x_2 = self._transform_inputs(inputs2)  # len=4: 1/4, 1/8, 1/16, 1/32

        # Features of img1 and img2
        c1_1, c2_1, c3_1, c4_1 = x_1
        c1_2, c2_2, c3_2, c4_2 = x_2

        ############## MLP decoder on C1-C4 ###########
        n, _, h, w = c4_1.shape
        outputs = []

        # Stage 4: x1/32 scale
        _c4_1 = self.linear_c4(c4_1).transpose([0, 2, 1])
        _c4_1 = _c4_1.reshape([
            n, calc_product(*_c4_1.shape[1:]) //
            (c4_1.shape[2] * c4_1.shape[3]), c4_1.shape[2], c4_1.shape[3]
        ])
        _c4_2 = self.linear_c4(c4_2).transpose([0, 2, 1])
        _c4_2 = _c4_2.reshape([
            n, calc_product(*_c4_2.shape[1:]) //
            (c4_2.shape[2] * c4_2.shape[3]), c4_2.shape[2], c4_2.shape[3]
        ])
        _c4 = self.diff_c4(pd.concat((_c4_1, _c4_2), axis=1))
        p_c4 = self.make_pred_c4(_c4)
        outputs.append(p_c4)
        _c4_up = resize(
            _c4, size=c1_2.shape[2:], mode='bilinear', align_corners=False)

        # Stage 3: x1/16 scale
        _c3_1 = self.linear_c3(c3_1).transpose([0, 2, 1])
        _c3_1 = _c3_1.reshape([
            n, calc_product(*_c3_1.shape[1:]) //
            (c3_1.shape[2] * c3_1.shape[3]), c3_1.shape[2], c3_1.shape[3]
        ])
        _c3_2 = self.linear_c3(c3_2).transpose([0, 2, 1])
        _c3_2 = _c3_2.reshape([
            n, calc_product(*_c3_2.shape[1:]) //
            (c3_2.shape[2] * c3_2.shape[3]), c3_2.shape[2], c3_2.shape[3]
        ])
        _c3 = self.diff_c3(pd.concat((_c3_1, _c3_2), axis=1)) + \
            F.interpolate(_c4, scale_factor=2, mode="bilinear")
        p_c3 = self.make_pred_c3(_c3)
        outputs.append(p_c3)
        _c3_up = resize(
            _c3, size=c1_2.shape[2:], mode='bilinear', align_corners=False)

        # Stage 2: x1/8 scale
        _c2_1 = self.linear_c2(c2_1).transpose([0, 2, 1])
        _c2_1 = _c2_1.reshape([
            n, calc_product(*_c2_1.shape[1:]) //
            (c2_1.shape[2] * c2_1.shape[3]), c2_1.shape[2], c2_1.shape[3]
        ])
        _c2_2 = self.linear_c2(c2_2).transpose([0, 2, 1])
        _c2_2 = _c2_2.reshape([
            n, calc_product(*_c2_2.shape[1:]) //
            (c2_2.shape[2] * c2_2.shape[3]), c2_2.shape[2], c2_2.shape[3]
        ])
        _c2 = self.diff_c2(pd.concat((_c2_1, _c2_2), axis=1)) + \
            F.interpolate(_c3, scale_factor=2, mode="bilinear")
        p_c2 = self.make_pred_c2(_c2)
        outputs.append(p_c2)
        _c2_up = resize(
            _c2, size=c1_2.shape[2:], mode='bilinear', align_corners=False)

        # Stage 1: x1/4 scale
        _c1_1 = self.linear_c1(c1_1).transpose([0, 2, 1])
        _c1_1 = _c1_1.reshape([
            n, calc_product(*_c1_1.shape[1:]) //
            (c1_1.shape[2] * c1_1.shape[3]), c1_1.shape[2], c1_1.shape[3]
        ])
        _c1_2 = self.linear_c1(c1_2).transpose([0, 2, 1])
        _c1_2 = _c1_2.reshape([
            n, calc_product(*_c1_2.shape[1:]) //
            (c1_2.shape[2] * c1_2.shape[3]), c1_2.shape[2], c1_2.shape[3]
        ])
        _c1 = self.diff_c1(pd.concat((_c1_1, _c1_2), axis=1)) + \
            F.interpolate(_c2, scale_factor=2, mode="bilinear")
        p_c1 = self.make_pred_c1(_c1)
        outputs.append(p_c1)

        # Linear fusion of the difference features from all scales
        _c = self.linear_fuse(pd.concat((_c4_up, _c3_up, _c2_up, _c1), axis=1))

        # Upsampling x2 (x1/2 scale)
        x = self.convd2x(_c)
        # Residual block
        x = self.dense_2x(x)
        # Upsampling x2 (x1 scale)
        x = self.convd1x(x)
        # Residual block
        x = self.dense_1x(x)

        # Final prediction
        cp = self.change_probability(x)
        outputs.append(cp)

        if self.output_softmax:
            temp = outputs
            outputs = []
            for pred in temp:
                outputs.append(self.active(pred))

        return outputs[-1]
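
# Illustrative sketch (not part of the original file): feed the decoder a
# synthetic four-level pyramid shaped like the encoder outputs for a
# 256x256 input. The temporal pair is identical here purely for brevity.
def _demo_decoder():
    dec = DecoderTransformer_v3(
        in_channels=[32, 64, 128, 256], embedding_dim=64, output_nc=2)
    feats = [
        pd.rand([1, c, s, s])
        for c, s in zip([32, 64, 128, 256], [64, 32, 16, 8])
    ]
    out = dec(feats, feats)
    return out  # change-map logits of shape [1, 2, 256, 256]
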
class OverlapPatchEmbed(nn.Layer):
    """
    Image to patch embedding.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=7,
                 stride=4,
                 in_chans=3,
                 embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // patch_size[0], img_size[
            1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2D(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
            trunc_normal_op(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                init_bias = nn.initializer.Constant(0)
                init_bias(m.bias)
        elif isinstance(m, nn.LayerNorm):
            init_bias = nn.initializer.Constant(0)
            init_bias(m.bias)
            init_weight = nn.initializer.Constant(1.0)
            init_weight(m.weight)
        elif isinstance(m, nn.Conv2D):
            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
            fan_out //= m._groups
            init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
            init_weight(m.weight)
            if m.bias is not None:
                init_bias = nn.initializer.Constant(0)
                init_bias(m.bias)

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose([0, 2, 1])
        x = self.norm(x)
        return x, H, W
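
# Illustrative sketch (not in the original file): the overlapping patch
# embedding above downsamples by its stride while returning the new grid
# size, e.g. a 256x256 image with stride 4 yields a 64x64 token grid.
def _demo_overlap_patch_embed():
    embed = OverlapPatchEmbed(
        img_size=256, patch_size=7, stride=4, in_chans=3, embed_dim=64)
    x = pd.rand([1, 3, 256, 256])
    tokens, H, W = embed(x)  # tokens: [1, 64 * 64, 64], H == W == 64
    return tokens, H, W
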
def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > input_w:
                if ((output_h > 1 and output_w > 1 and input_h > 1 and
                     input_w > 1) and (output_h - 1) % (input_h - 1) and
                        (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would be more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    return F.interpolate(input, size, scale_factor, mode, align_corners)


class Mlp(nn.Layer):
    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        # A depthwise convolution between the two linear layers injects
        # spatial (positional) information into the token sequence.
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
            trunc_normal_op(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                init_bias = nn.initializer.Constant(0)
                init_bias(m.bias)
        elif isinstance(m, nn.LayerNorm):
            init_bias = nn.initializer.Constant(0)
            init_bias(m.bias)
            init_weight = nn.initializer.Constant(1.0)
            init_weight(m.weight)
        elif isinstance(m, nn.Conv2D):
            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
            fan_out //= m._groups
            init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
            init_weight(m.weight)
            if m.bias is not None:
                init_bias = nn.initializer.Constant(0)
                init_bias(m.bias)

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class Attention(nn.Layer):
    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, \
            f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.q = nn.Linear(dim, dim, bias_attr=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Spatial reduction: shrink the key/value token grid by `sr_ratio`
        # before attention to cut the cost of the attention matrix.
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2D(
                dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
            trunc_normal_op(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                init_bias = nn.initializer.Constant(0)
                init_bias(m.bias)
        elif isinstance(m, nn.LayerNorm):
            init_bias = nn.initializer.Constant(0)
            init_bias(m.bias)
            init_weight = nn.initializer.Constant(1.0)
            init_weight(m.weight)
        elif isinstance(m, nn.Conv2D):
            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
            fan_out //= m._groups
            init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
            init_weight(m.weight)
            if m.bias is not None:
                init_bias = nn.initializer.Constant(0)
                init_bias(m.bias)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape([B, N, self.num_heads,
                               C // self.num_heads]).transpose([0, 2, 1, 3])

        if self.sr_ratio > 1:
            x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
            x_ = self.sr(x_)
            x_ = x_.reshape([B, C, calc_product(*x_.shape[2:])]).transpose(
                [0, 2, 1])
            x_ = self.norm(x_)
            kv = self.kv(x_)
            kv = kv.reshape([
                B, calc_product(*kv.shape[1:]) // (2 * C), 2, self.num_heads,
                C // self.num_heads
            ]).transpose([2, 0, 3, 1, 4])
        else:
            kv = self.kv(x)
            kv = kv.reshape([
                B, calc_product(*kv.shape[1:]) // (2 * C), 2, self.num_heads,
                C // self.num_heads
            ]).transpose([2, 0, 3, 1, 4])
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale
        attn = F.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose([0, 2, 1, 3]).reshape([B, N, C])
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
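
# Illustrative sketch (not part of the original file): the
# spatial-reduction attention above shrinks keys/values by `sr_ratio`
# before computing attention, so for a 64x64 token grid with sr_ratio=8
# the attention matrix is (64*64) x (8*8) instead of (64*64) x (64*64).
def _demo_sr_attention():
    attn = Attention(dim=64, num_heads=1, sr_ratio=8)
    H = W = 64
    tokens = pd.rand([1, H * W, 64])  # [B, N, C] token sequence
    out = attn(tokens, H, W)  # keys/values reduced to an 8x8 grid
    return out  # [1, 4096, 64]
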
class Block(nn.Layer):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio)
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=mlp_hidden_dim,
                       act_layer=act_layer,
                       drop=drop)

        # Apply the weight initializer; without this call `_init_weights`
        # below would be dead code.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_op = nn.initializer.TruncatedNormal(std=.02)
            trunc_normal_op(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                init_bias = nn.initializer.Constant(0)
                init_bias(m.bias)
        elif isinstance(m, nn.LayerNorm):
            init_bias = nn.initializer.Constant(0)
            init_bias(m.bias)
            init_weight = nn.initializer.Constant(1.0)
            init_weight(m.weight)
        elif isinstance(m, nn.Conv2D):
            fan_out = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
            fan_out //= m._groups
            init_weight = nn.initializer.Normal(0, math.sqrt(2.0 / fan_out))
            init_weight(m.weight)
            if m.bias is not None:
                init_bias = nn.initializer.Constant(0)
                init_bias(m.bias)

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x


class DWConv(nn.Layer):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2D(dim, dim, 3, 1, 1, bias_attr=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose([0, 2, 1]).reshape([B, C, H, W])
        x = self.dwconv(x)
        x = x.flatten(2).transpose([0, 2, 1])
        return x


# Transformer decoder helpers
class MLP(nn.Layer):
    """
    Linear embedding.
    """

    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose([0, 2, 1])
        x = self.proj(x)
        return x


# Difference layer
def conv_diff(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2D(
            in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.BatchNorm2D(out_channels),
        nn.Conv2D(
            out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU())


# Intermediate prediction layer
def make_prediction(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2D(
            in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.BatchNorm2D(out_channels),
        nn.Conv2D(
            out_channels, out_channels, kernel_size=3, padding=1))
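
# Quick smoke test (illustrative only, not part of the original file):
# exercise the full model on a random bi-temporal pair. The guard keeps
# the demo from running on import.
if __name__ == '__main__':
    pred = _demo_changeformer()
    print(pred.shape)  # expected: [1, 2, 256, 256]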