diff --git a/paddlers/models/cd/models/__init__.py b/paddlers/models/cd/models/__init__.py
index 08da15f..7219a26 100644
--- a/paddlers/models/cd/models/__init__.py
+++ b/paddlers/models/cd/models/__init__.py
@@ -12,4 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .cdnet import CDNet
\ No newline at end of file
+from .bit import BIT
+from .cdnet import CDNet
+from .dsifn import DSIFN
+from .stanet import STANet
+from .snunet import SNUNet
+from .dsamnet import DSAMNet
+from .unet_ef import UNetEarlyFusion
+from .unet_siamconc import UNetSiamConc
+from .unet_siamdiff import UNetSiamDiff
\ No newline at end of file
diff --git a/paddlers/models/cd/models/backbones/__init__.py b/paddlers/models/cd/models/backbones/__init__.py
new file mode 100644
index 0000000..eeae9aa
--- /dev/null
+++ b/paddlers/models/cd/models/backbones/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
\ No newline at end of file
diff --git a/paddlers/models/cd/models/backbones/resnet.py b/paddlers/models/cd/models/backbones/resnet.py
new file mode 100644
index 0000000..b5bb823
--- /dev/null
+++ b/paddlers/models/cd/models/backbones/resnet.py
@@ -0,0 +1,358 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Adapted from https://github.com/PaddlePaddle/Paddle/blob/release/2.2/python/paddle/vision/models/resnet.py
+## Original header information
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
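+
+# Note: compared with the upstream paddle.vision implementation, this adapted
+# ResNet additionally accepts `strides` and `norm_layer` arguments, which the
+# change detection models below use to control the resolution of intermediate
+# feature maps. An illustrative call:
+#
+#     resnet18(pretrained=False, strides=(1, 1, 2, 2, 1), norm_layer=nn.BatchNorm2D)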
+ +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn + +from paddle.utils.download import get_weights_path_from_url + +__all__ = [] + +model_urls = { + 'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams', + 'cf548f46534aa3560945be4b95cd11c4'), + 'resnet34': ('https://paddle-hapi.bj.bcebos.com/models/resnet34.pdparams', + '8d2275cf8706028345f78ac0e1d31969'), + 'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams', + 'ca6f485ee1ab0492d38f323885b0ad80'), + 'resnet101': ('https://paddle-hapi.bj.bcebos.com/models/resnet101.pdparams', + '02f35f034ca3858e1e54d4036443c92d'), + 'resnet152': ('https://paddle-hapi.bj.bcebos.com/models/resnet152.pdparams', + '7ad16a2f1e7333859ff986138630fd7a'), +} + + +class BasicBlock(nn.Layer): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2D + + if dilation > 1: + raise NotImplementedError( + "Dilation > 1 not supported in BasicBlock") + + self.conv1 = nn.Conv2D( + inplanes, planes, 3, padding=1, stride=stride, bias_attr=False) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class BottleneckBlock(nn.Layer): + + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(BottleneckBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2D + width = int(planes * (base_width / 64.)) * groups + + self.conv1 = nn.Conv2D(inplanes, width, 1, bias_attr=False) + self.bn1 = norm_layer(width) + + self.conv2 = nn.Conv2D( + width, + width, + 3, + padding=dilation, + stride=stride, + groups=groups, + dilation=dilation, + bias_attr=False) + self.bn2 = norm_layer(width) + + self.conv3 = nn.Conv2D( + width, planes * self.expansion, 1, bias_attr=False) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU() + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Layer): + """ResNet model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + Block (BasicBlock|BottleneckBlock): block module of model. + depth (int): layers of resnet, default: 50. + num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer + will not be defined. Default: 1000. + with_pool (bool): use pool before the last fc layer or not. Default: True. + Examples: + .. 
code-block:: python
+            from paddle.vision.models import ResNet
+            from paddle.vision.models.resnet import BottleneckBlock, BasicBlock
+            resnet50 = ResNet(BottleneckBlock, 50)
+            resnet18 = ResNet(BasicBlock, 18)
+    """
+
+    def __init__(self, block, depth, num_classes=1000, with_pool=True, strides=(1,1,2,2,2), norm_layer=None):
+        super(ResNet, self).__init__()
+        layer_cfg = {
+            18: [2, 2, 2, 2],
+            34: [3, 4, 6, 3],
+            50: [3, 4, 6, 3],
+            101: [3, 4, 23, 3],
+            152: [3, 8, 36, 3]
+        }
+        layers = layer_cfg[depth]
+        self.num_classes = num_classes
+        self.with_pool = with_pool
+        self._norm_layer = nn.BatchNorm2D if norm_layer is None else norm_layer
+
+        self.inplanes = 64
+        self.dilation = 1
+
+        self.conv1 = nn.Conv2D(
+            3,
+            self.inplanes,
+            kernel_size=7,
+            stride=strides[0],
+            padding=3,
+            bias_attr=False)
+        self.bn1 = self._norm_layer(self.inplanes)
+        self.relu = nn.ReLU()
+        self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[1])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[2])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[3])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[4])
+        if with_pool:
+            self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
+
+        if num_classes > 0:
+            self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2D(
+                    self.inplanes,
+                    planes * block.expansion,
+                    1,
+                    stride=stride,
+                    bias_attr=False),
+                norm_layer(planes * block.expansion), )
+
+        layers = []
+        layers.append(
+            block(self.inplanes, planes, stride, downsample, 1, 64,
+                  previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, norm_layer=norm_layer))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        if self.with_pool:
+            x = self.avgpool(x)
+
+        if self.num_classes > 0:
+            x = paddle.flatten(x, 1)
+            x = self.fc(x)
+
+        return x
+
+
+def _resnet(arch, Block, depth, pretrained, **kwargs):
+    model = ResNet(Block, depth, **kwargs)
+    if pretrained:
+        assert arch in model_urls, "{} model does not have a pretrained model now, you should set pretrained=False".format(
+            arch)
+        weight_path = get_weights_path_from_url(model_urls[arch][0],
+                                                model_urls[arch][1])
+
+        param = paddle.load(weight_path)
+        model.set_dict(param)
+
+    return model
+
+
+def resnet18(pretrained=False, **kwargs):
+    """ResNet 18-layer model
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    Examples:
+        .. code-block:: python
+            from paddle.vision.models import resnet18
+            # build model
+            model = resnet18()
+            # build model and load imagenet pretrained weight
+            # model = resnet18(pretrained=True)
+    """
+    return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs)
+
+
+def resnet34(pretrained=False, **kwargs):
+    """ResNet 34-layer model
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+
+    Examples:
+        ..
code-block:: python + from paddle.vision.models import resnet34 + # build model + model = resnet34() + # build model and load imagenet pretrained weight + # model = resnet34(pretrained=True) + """ + return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs) + + +def resnet50(pretrained=False, **kwargs): + """ResNet 50-layer model + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + Examples: + .. code-block:: python + from paddle.vision.models import resnet50 + # build model + model = resnet50() + # build model and load imagenet pretrained weight + # model = resnet50(pretrained=True) + """ + return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs) + + +def resnet101(pretrained=False, **kwargs): + """ResNet 101-layer model + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + Examples: + .. code-block:: python + from paddle.vision.models import resnet101 + # build model + model = resnet101() + # build model and load imagenet pretrained weight + # model = resnet101(pretrained=True) + """ + return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs) + + +def resnet152(pretrained=False, **kwargs): + """ResNet 152-layer model + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + Examples: + .. code-block:: python + from paddle.vision.models import resnet152 + # build model + model = resnet152() + # build model and load imagenet pretrained weight + # model = resnet152(pretrained=True) + """ + return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs) \ No newline at end of file diff --git a/paddlers/models/cd/models/bit.py b/paddlers/models/cd/models/bit.py new file mode 100644 index 0000000..5a66f9f --- /dev/null +++ b/paddlers/models/cd/models/bit.py @@ -0,0 +1,395 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn.initializer import Normal + + +from .backbones import resnet +from .layers import Conv3x3, Conv1x1, get_norm_layer, Identity +from .param_init import KaimingInitMixin + + +class BIT(nn.Layer): + """ + The BIT implementation based on PaddlePaddle. + + The original article refers to + H. Chen, et al., "Remote Sensing Image Change Detection With Transformers" + (https://arxiv.org/abs/2103.00208). + + This implementation adopts pretrained encoders, as opposed to the original work where weights are randomly initialized. + + Args: + in_channels (int): The number of bands of the input images. + num_classes (int): The number of target classes. + backbone (str, optional): The ResNet architecture that is used as the backbone. Currently, only 'resnet18' and + 'resnet34' are supported. Default: 'resnet18'. + n_stages (int, optional): The number of ResNet stages used in the backbone, which should be a value in {3,4,5}. + Default: 4. + use_tokenizer (bool, optional): Use a tokenizer or not. Default: True. 
+        token_len (int, optional): The length of input tokens. Default: 4.
+        pool_mode (str, optional): The pooling strategy to obtain input tokens when `use_tokenizer` is set to False. 'max'
+            for global max pooling and 'avg' for global average pooling. Default: 'max'.
+        pool_size (int, optional): The height and width of the pooled feature maps when `use_tokenizer` is set to False.
+            Default: 2.
+        enc_with_pos (bool, optional): Whether to add learned positional embedding to the input feature sequence of the
+            encoder. Default: True.
+        enc_depth (int, optional): The number of attention blocks used in the encoder. Default: 1.
+        enc_head_dim (int, optional): The embedding dimension of each encoder head. Default: 64.
+        dec_depth (int, optional): The number of attention blocks used in the decoder. Default: 8.
+        dec_head_dim (int, optional): The embedding dimension of each decoder head. Default: 8.
+
+    Raises:
+        ValueError: When an unsupported backbone type is specified, or the number of backbone stages is not 3, 4, or 5.
+    """
+
+    def __init__(
+        self, in_channels, num_classes,
+        backbone='resnet18', n_stages=4,
+        use_tokenizer=True, token_len=4,
+        pool_mode='max', pool_size=2,
+        enc_with_pos=True,
+        enc_depth=1, enc_head_dim=64,
+        dec_depth=8, dec_head_dim=8,
+        **backbone_kwargs
+    ):
+        super().__init__()
+
+        # TODO: reduce hard-coded parameters
+        DIM = 32
+        MLP_DIM = 2*DIM
+        EBD_DIM = DIM
+
+        self.backbone = Backbone(in_channels, EBD_DIM, arch=backbone, n_stages=n_stages, **backbone_kwargs)
+
+        self.use_tokenizer = use_tokenizer
+        if not use_tokenizer:
+            # If a tokenizer is not to be used, then downsample the feature maps.
+            self.pool_size = pool_size
+            self.pool_mode = pool_mode
+            self.token_len = pool_size * pool_size
+        else:
+            self.conv_att = Conv1x1(32, token_len, bias=False)
+            self.token_len = token_len
+
+        self.enc_with_pos = enc_with_pos
+        if enc_with_pos:
+            self.enc_pos_embedding = self.create_parameter(
+                shape=(1,self.token_len*2,EBD_DIM),
+                default_initializer=Normal()
+            )
+
+        self.enc_depth = enc_depth
+        self.dec_depth = dec_depth
+        self.enc_head_dim = enc_head_dim
+        self.dec_head_dim = dec_head_dim
+
+        self.encoder = TransformerEncoder(
+            dim=DIM,
+            depth=enc_depth,
+            n_heads=8,
+            head_dim=enc_head_dim,
+            mlp_dim=MLP_DIM,
+            dropout_rate=0.
+ ) + self.decoder = TransformerDecoder( + dim=DIM, + depth=dec_depth, + n_heads=8, + head_dim=dec_head_dim, + mlp_dim=MLP_DIM, + dropout_rate=0., + apply_softmax=True + ) + + self.upsample = nn.Upsample(scale_factor=4, mode='bilinear') + self.conv_out = nn.Sequential( + Conv3x3(EBD_DIM, EBD_DIM, norm=True, act=True), + Conv3x3(EBD_DIM, num_classes) + ) + + def _get_semantic_tokens(self, x): + b, c = paddle.shape(x)[:2] + att_map = self.conv_att(x) + att_map = att_map.reshape((b,self.token_len,1,-1)) + att_map = F.softmax(att_map, axis=-1) + x = x.reshape((b,1,c,-1)) + tokens = (x*att_map).sum(-1) + return tokens + + def _get_reshaped_tokens(self, x): + if self.pool_mode == 'max': + x = F.adaptive_max_pool2d(x, (self.pool_size, self.pool_size)) + elif self.pool_mode == 'avg': + x = F.adaptive_avg_pool2d(x, (self.pool_size, self.pool_size)) + else: + x = x + tokens = x.transpose((0,2,3,1)).flatten(1,2) + return tokens + + def encode(self, x): + if self.enc_with_pos: + x += self.enc_pos_embedding + x = self.encoder(x) + return x + + def decode(self, x, m): + b, c, h, w = paddle.shape(x) + x = x.transpose((0,2,3,1)).flatten(1,2) + x = self.decoder(x, m) + x = x.transpose((0,2,1)).reshape((b,c,h,w)) + return x + + def forward(self, t1, t2): + # Extract features via shared backbone. + x1 = self.backbone(t1) + x2 = self.backbone(t2) + + # Tokenization + if self.use_tokenizer: + token1 = self._get_semantic_tokens(x1) + token2 = self._get_semantic_tokens(x2) + else: + token1 = self._get_reshaped_tokens(x1) + token2 = self._get_reshaped_tokens(x2) + + # Transformer encoder forward + token = paddle.concat([token1, token2], axis=1) + token = self.encode(token) + token1, token2 = paddle.chunk(token, 2, axis=1) + + # Transformer decoder forward + y1 = self.decode(x1, token1) + y2 = self.decode(x2, token2) + + # Feature differencing + y = paddle.abs(y1 - y2) + y = self.upsample(y) + + # Classifier forward + pred = self.conv_out(y) + return pred, + + def init_weight(self): + # Use the default initialization method. 
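+        # The backbone initializes its own weights (Kaiming initialization when
+        # no pretrained weights are loaded); other sublayers keep Paddle defaults.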
+ pass + + +class Residual(nn.Layer): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) + x + + +class Residual2(nn.Layer): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x1, x2, **kwargs): + return self.fn(x1, x2, **kwargs) + x1 + + +class PreNorm(nn.Layer): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + + +class PreNorm2(nn.Layer): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x1, x2, **kwargs): + return self.fn(self.norm(x1), self.norm(x2), **kwargs) + + +class FeedForward(nn.Sequential): + def __init__(self, dim, hidden_dim, dropout_rate=0.): + super().__init__( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout_rate), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout_rate) + ) + + +class CrossAttention(nn.Layer): + def __init__(self, dim, n_heads=8, head_dim=64, dropout_rate=0., apply_softmax=True): + super().__init__() + + inner_dim = head_dim * n_heads + self.n_heads = n_heads + self.scale = dim ** -0.5 + + self.apply_softmax = apply_softmax + + self.fc_q = nn.Linear(dim, inner_dim, bias_attr=False) + self.fc_k = nn.Linear(dim, inner_dim, bias_attr=False) + self.fc_v = nn.Linear(dim, inner_dim, bias_attr=False) + + self.fc_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout_rate) + ) + + def forward(self, x, ref): + b, n = paddle.shape(x)[:2] + h = self.n_heads + + q = self.fc_q(x) + k = self.fc_k(ref) + v = self.fc_v(ref) + + q = q.reshape((b,n,h,-1)).transpose((0,2,1,3)) + k = k.reshape((b,paddle.shape(ref)[1],h,-1)).transpose((0,2,1,3)) + v = v.reshape((b,paddle.shape(ref)[1],h,-1)).transpose((0,2,1,3)) + + mult = paddle.matmul(q, k, transpose_y=True) * self.scale + + if self.apply_softmax: + mult = F.softmax(mult, axis=-1) + + out = paddle.matmul(mult, v) + out = out.transpose((0,2,1,3)).flatten(2) + return self.fc_out(out) + + +class SelfAttention(CrossAttention): + def forward(self, x): + return super().forward(x, x) + + +class TransformerEncoder(nn.Layer): + def __init__(self, dim, depth, n_heads, head_dim, mlp_dim, dropout_rate): + super().__init__() + self.layers = nn.LayerList([]) + for _ in range(depth): + self.layers.append(nn.LayerList([ + Residual(PreNorm(dim, SelfAttention(dim, n_heads, head_dim, dropout_rate))), + Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))) + ])) + + def forward(self, x): + for att, ff in self.layers: + x = att(x) + x = ff(x) + return x + + +class TransformerDecoder(nn.Layer): + def __init__(self, dim, depth, n_heads, head_dim, mlp_dim, dropout_rate, apply_softmax=True): + super().__init__() + self.layers = nn.LayerList([]) + for _ in range(depth): + self.layers.append(nn.LayerList([ + Residual2(PreNorm2(dim, CrossAttention(dim, n_heads, head_dim, dropout_rate, apply_softmax))), + Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))) + ])) + + def forward(self, x, m): + for att, ff in self.layers: + x = att(x, m) + x = ff(x) + return x + + +class Backbone(nn.Layer, KaimingInitMixin): + def __init__( + self, + in_ch, out_ch=32, + arch='resnet18', + pretrained=True, + n_stages=5 + ): + super().__init__() + + expand = 1 + strides = (2,1,2,1,1) + if arch == 'resnet18': + self.resnet = resnet.resnet18(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer()) + elif 
arch == 'resnet34':
+            self.resnet = resnet.resnet34(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer())
+        else:
+            raise ValueError
+
+        self.n_stages = n_stages
+
+        if self.n_stages == 5:
+            itm_ch = 512 * expand
+        elif self.n_stages == 4:
+            itm_ch = 256 * expand
+        elif self.n_stages == 3:
+            itm_ch = 128 * expand
+        else:
+            raise ValueError
+
+        self.upsample = nn.Upsample(scale_factor=2)
+        self.conv_out = Conv3x3(itm_ch, out_ch)
+
+        self._trim_resnet()
+
+        if in_ch != 3:
+            self.resnet.conv1 = nn.Conv2D(
+                in_ch,
+                64,
+                kernel_size=7,
+                stride=2,
+                padding=3,
+                bias_attr=False
+            )
+
+        if not pretrained:
+            self.init_weight()
+
+    def forward(self, x):
+        y = self.resnet.conv1(x)
+        y = self.resnet.bn1(y)
+        y = self.resnet.relu(y)
+        y = self.resnet.maxpool(y)
+
+        y = self.resnet.layer1(y)
+        y = self.resnet.layer2(y)
+        y = self.resnet.layer3(y)
+        y = self.resnet.layer4(y)
+
+        y = self.upsample(y)
+
+        return self.conv_out(y)
+
+    def _trim_resnet(self):
+        if self.n_stages > 5:
+            raise ValueError
+
+        if self.n_stages < 5:
+            self.resnet.layer4 = Identity()
+
+        if self.n_stages <= 3:
+            self.resnet.layer3 = Identity()
+
+        self.resnet.avgpool = Identity()
+        self.resnet.fc = Identity()
\ No newline at end of file
diff --git a/paddlers/models/cd/models/dsamnet.py b/paddlers/models/cd/models/dsamnet.py
new file mode 100644
index 0000000..3bcf2b4
--- /dev/null
+++ b/paddlers/models/cd/models/dsamnet.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+from .layers import make_norm, Conv3x3, CBAM
+from .stanet import Backbone, Decoder
+
+
+class DSAMNet(nn.Layer):
+    """
+    The DSAMNet implementation based on PaddlePaddle.
+
+    The original article refers to
+    Q. Shi, et al., "A Deeply Supervised Attention Metric-Based Network and an Open Aerial Image Dataset for Remote Sensing
+    Change Detection"
+    (https://ieeexplore.ieee.org/document/9467555).
+
+    Note that this implementation differs from the original work in two aspects:
+    1. We do not use multiple dilation rates in layer 4 of the ResNet backbone.
+    2. A classification head is used in place of the original metric learning-based head to stabilize the training process.
+
+    Args:
+        in_channels (int): The number of bands of the input images.
+        num_classes (int): The number of target classes.
+        ca_ratio (int, optional): The channel reduction ratio for the channel attention module. Default: 8.
+        sa_kernel (int, optional): The size of the convolutional kernel used in the spatial attention module. Default: 7.
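+
+    Examples:
+        A minimal forward-pass sketch (batch size and image size are illustrative):
+
+        .. code-block:: python
+
+            import paddle
+            from paddlers.models.cd.models import DSAMNet
+
+            model = DSAMNet(in_channels=3, num_classes=2)
+            t1 = paddle.rand([1, 3, 256, 256])
+            t2 = paddle.rand([1, 3, 256, 256])
+            # The main prediction plus two deeply supervised auxiliary outputs.
+            pred, ds2, ds3 = model(t1, t2)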
+ """ + + def __init__(self, in_channels, num_classes, ca_ratio=8, sa_kernel=7): + super().__init__() + + WIDTH = 64 + + self.backbone = Backbone(in_ch=in_channels, arch='resnet18', strides=(1,1,2,2,1)) + self.decoder = Decoder(WIDTH) + + self.cbam1 = CBAM(64, ratio=ca_ratio, kernel_size=sa_kernel) + self.cbam2 = CBAM(64, ratio=ca_ratio, kernel_size=sa_kernel) + + self.dsl2 = DSLayer(64, num_classes, 32, stride=2, output_padding=1) + self.dsl3 = DSLayer(128, num_classes, 32, stride=4, output_padding=3) + + self.conv_out = nn.Sequential( + Conv3x3(WIDTH, WIDTH, norm=True, act=True), + Conv3x3(WIDTH, num_classes) + ) + + self.init_weight() + + def forward(self, t1, t2): + f1 = self.backbone(t1) + f2 = self.backbone(t2) + + y1 = self.decoder(f1) + y2 = self.decoder(f2) + + y1 = self.cbam1(y1) + y2 = self.cbam2(y2) + + out = paddle.abs(y1-y2) + out = F.interpolate(out, size=paddle.shape(t1)[2:], mode='bilinear', align_corners=True) + pred = self.conv_out(out) + + ds2 = self.dsl2(paddle.abs(f1[0]-f2[0])) + ds3 = self.dsl3(paddle.abs(f1[1]-f2[1])) + + return pred, ds2, ds3 + + def init_weight(self): + pass + + +class DSLayer(nn.Sequential): + def __init__(self, in_ch, out_ch, itm_ch, **convd_kwargs): + super().__init__( + nn.Conv2DTranspose(in_ch, itm_ch, kernel_size=3, padding=1, **convd_kwargs), + make_norm(itm_ch), + nn.ReLU(), + nn.Dropout2D(p=0.2), + nn.Conv2DTranspose(itm_ch, out_ch, kernel_size=3, padding=1) + ) \ No newline at end of file diff --git a/paddlers/models/cd/models/dsifn.py b/paddlers/models/cd/models/dsifn.py new file mode 100644 index 0000000..c8d8e28 --- /dev/null +++ b/paddlers/models/cd/models/dsifn.py @@ -0,0 +1,209 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.vision.models import vgg16 + + +from .layers import Conv1x1, make_norm, ChannelAttention, SpatialAttention + + +class DSIFN(nn.Layer): + """ + The DSIFN implementation based on PaddlePaddle. + + The original article refers to + C. Zhang, et al., "A deeply supervised image fusion network for change detection in high resolution bi-temporal remote + sensing images" + (https://www.sciencedirect.com/science/article/pii/S0924271620301532). + + Note that in this implementation, there is a flexible number of target classes. + + Args: + num_classes (int): The number of target classes. + use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is trained + on a relatively small dataset, the dropout layers help prevent overfitting. Default: False. 
+ """ + + def __init__(self, num_classes, use_dropout=False): + super().__init__() + + self.encoder1 = self.encoder2 = VGG16FeaturePicker() + + self.sa1 = SpatialAttention() + self.sa2= SpatialAttention() + self.sa3 = SpatialAttention() + self.sa4 = SpatialAttention() + self.sa5 = SpatialAttention() + + self.ca1 = ChannelAttention(in_ch=1024) + self.bn_ca1 = make_norm(1024) + self.o1_conv1 = conv2d_bn(1024, 512, use_dropout) + self.o1_conv2 = conv2d_bn(512, 512, use_dropout) + self.bn_sa1 = make_norm(512) + self.o1_conv3 = Conv1x1(512, num_classes) + self.trans_conv1 = nn.Conv2DTranspose(512, 512, kernel_size=2, stride=2) + + self.ca2 = ChannelAttention(in_ch=1536) + self.bn_ca2 = make_norm(1536) + self.o2_conv1 = conv2d_bn(1536, 512, use_dropout) + self.o2_conv2 = conv2d_bn(512, 256, use_dropout) + self.o2_conv3 = conv2d_bn(256, 256, use_dropout) + self.bn_sa2 = make_norm(256) + self.o2_conv4 = Conv1x1(256, num_classes) + self.trans_conv2 = nn.Conv2DTranspose(256, 256, kernel_size=2, stride=2) + + self.ca3 = ChannelAttention(in_ch=768) + self.o3_conv1 = conv2d_bn(768, 256, use_dropout) + self.o3_conv2 = conv2d_bn(256, 128, use_dropout) + self.o3_conv3 = conv2d_bn(128, 128, use_dropout) + self.bn_sa3 = make_norm(128) + self.o3_conv4 = Conv1x1(128, num_classes) + self.trans_conv3 = nn.Conv2DTranspose(128, 128, kernel_size=2, stride=2) + + self.ca4 = ChannelAttention(in_ch=384) + self.o4_conv1 = conv2d_bn(384, 128, use_dropout) + self.o4_conv2 = conv2d_bn(128, 64, use_dropout) + self.o4_conv3 = conv2d_bn(64, 64, use_dropout) + self.bn_sa4 = make_norm(64) + self.o4_conv4 = Conv1x1(64, num_classes) + self.trans_conv4 = nn.Conv2DTranspose(64, 64, kernel_size=2, stride=2) + + self.ca5 = ChannelAttention(in_ch=192) + self.o5_conv1 = conv2d_bn(192, 64, use_dropout) + self.o5_conv2 = conv2d_bn(64, 32, use_dropout) + self.o5_conv3 = conv2d_bn(32, 16, use_dropout) + self.bn_sa5 = make_norm(16) + self.o5_conv4 = Conv1x1(16, num_classes) + + self.init_weight() + + def forward(self, t1, t2): + # Extract bi-temporal features. 
+        with paddle.no_grad():
+            self.encoder1.eval()
+            self.encoder2.eval()
+            t1_feats = self.encoder1(t1)
+            t2_feats = self.encoder2(t2)
+
+        t1_f_l3, t1_f_l8, t1_f_l15, t1_f_l22, t1_f_l29 = t1_feats
+        t2_f_l3, t2_f_l8, t2_f_l15, t2_f_l22, t2_f_l29 = t2_feats
+
+        # Multi-level decoding
+        x = paddle.concat([t1_f_l29, t2_f_l29], axis=1)
+        x = self.o1_conv1(x)
+        x = self.o1_conv2(x)
+        x = self.sa1(x) * x
+        x = self.bn_sa1(x)
+
+        out1 = F.interpolate(
+            self.o1_conv3(x),
+            size=paddle.shape(t1)[2:],
+            mode='bilinear',
+            align_corners=True
+        )
+
+        x = self.trans_conv1(x)
+        x = paddle.concat([x, t1_f_l22, t2_f_l22], axis=1)
+        x = self.ca2(x) * x
+        x = self.o2_conv1(x)
+        x = self.o2_conv2(x)
+        x = self.o2_conv3(x)
+        x = self.sa2(x) * x
+        x = self.bn_sa2(x)
+
+        out2 = F.interpolate(
+            self.o2_conv4(x),
+            size=paddle.shape(t1)[2:],
+            mode='bilinear',
+            align_corners=True
+        )
+
+        x = self.trans_conv2(x)
+        x = paddle.concat([x, t1_f_l15, t2_f_l15], axis=1)
+        x = self.ca3(x) * x
+        x = self.o3_conv1(x)
+        x = self.o3_conv2(x)
+        x = self.o3_conv3(x)
+        x = self.sa3(x) * x
+        x = self.bn_sa3(x)
+
+        out3 = F.interpolate(
+            self.o3_conv4(x),
+            size=paddle.shape(t1)[2:],
+            mode='bilinear',
+            align_corners=True
+        )
+
+        x = self.trans_conv3(x)
+        x = paddle.concat([x, t1_f_l8, t2_f_l8], axis=1)
+        x = self.ca4(x) * x
+        x = self.o4_conv1(x)
+        x = self.o4_conv2(x)
+        x = self.o4_conv3(x)
+        x = self.sa4(x) * x
+        x = self.bn_sa4(x)
+
+        out4 = F.interpolate(
+            self.o4_conv4(x),
+            size=paddle.shape(t1)[2:],
+            mode='bilinear',
+            align_corners=True
+        )
+
+        x = self.trans_conv4(x)
+        x = paddle.concat([x, t1_f_l3, t2_f_l3], axis=1)
+        x = self.ca5(x) * x
+        x = self.o5_conv1(x)
+        x = self.o5_conv2(x)
+        x = self.o5_conv3(x)
+        x = self.sa5(x) * x
+        x = self.bn_sa5(x)
+
+        out5 = self.o5_conv4(x)
+
+        return out5, out4, out3, out2, out1
+
+    def init_weight(self):
+        # Do nothing
+        pass
+
+
+class VGG16FeaturePicker(nn.Layer):
+    def __init__(self, indices=(3,8,15,22,29)):
+        super().__init__()
+        features = list(vgg16(pretrained=True).features)[:30]
+        self.features = nn.LayerList(features)
+        self.features.eval()
+        self.indices = set(indices)
+
+    def forward(self, x):
+        picked_feats = []
+        for idx, model in enumerate(self.features):
+            x = model(x)
+            if idx in self.indices:
+                picked_feats.append(x)
+        return picked_feats
+
+
+def conv2d_bn(in_ch, out_ch, with_dropout=True):
+    lst = [
+        nn.Conv2D(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
+        nn.PReLU(),
+        make_norm(out_ch),
+    ]
+    if with_dropout:
+        lst.append(nn.Dropout(p=0.6))
+    return nn.Sequential(*lst)
\ No newline at end of file
diff --git a/paddlers/models/cd/models/layers/__init__.py b/paddlers/models/cd/models/layers/__init__.py
new file mode 100644
index 0000000..ed9d985
--- /dev/null
+++ b/paddlers/models/cd/models/layers/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from .blocks import * +from .attention import ChannelAttention, SpatialAttention, CBAM \ No newline at end of file diff --git a/paddlers/models/cd/models/layers/attention.py b/paddlers/models/cd/models/layers/attention.py new file mode 100644 index 0000000..30d7954 --- /dev/null +++ b/paddlers/models/cd/models/layers/attention.py @@ -0,0 +1,96 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +from .blocks import Conv1x1, BasicConv + + +class ChannelAttention(nn.Layer): + """ + The channel attention module implementation based on PaddlePaddle. + + The original article refers to + Sanghyun Woo, et al., "CBAM: Convolutional Block Attention Module" + (https://arxiv.org/abs/1807.06521). + + Args: + in_ch (int): The number of channels of the input features. + ratio (int, optional): The channel reduction ratio. Default: 8. + """ + + def __init__(self, in_ch, ratio=8): + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2D(1) + self.max_pool = nn.AdaptiveMaxPool2D(1) + self.fc1 = Conv1x1(in_ch, in_ch//ratio, bias=False, act=True) + self.fc2 = Conv1x1(in_ch//ratio, in_ch, bias=False) + + def forward(self,x): + avg_out = self.fc2(self.fc1(self.avg_pool(x))) + max_out = self.fc2(self.fc1(self.max_pool(x))) + out = avg_out + max_out + return F.sigmoid(out) + + +class SpatialAttention(nn.Layer): + """ + The spatial attention module implementation based on PaddlePaddle. + + The original article refers to + Sanghyun Woo, et al., "CBAM: Convolutional Block Attention Module" + (https://arxiv.org/abs/1807.06521). + + Args: + kernel_size (int, optional): The size of the convolutional kernel. Default: 7. + """ + + def __init__(self, kernel_size=7): + super().__init__() + self.conv = BasicConv(2, 1, kernel_size, bias=False) + + def forward(self, x): + avg_out = paddle.mean(x, axis=1, keepdim=True) + max_out = paddle.max(x, axis=1, keepdim=True) + x = paddle.concat([avg_out, max_out], axis=1) + x = self.conv(x) + return F.sigmoid(x) + + +class CBAM(nn.Layer): + """ + The CBAM implementation based on PaddlePaddle. + + The original article refers to + Sanghyun Woo, et al., "CBAM: Convolutional Block Attention Module" + (https://arxiv.org/abs/1807.06521). + + Args: + in_ch (int): The number of channels of the input features. + ratio (int, optional): The channel reduction ratio for the channel attention module. Default: 8. + kernel_size (int, optional): The size of the convolutional kernel used in the spatial attention module. Default: 7. 
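+
+    Examples:
+        A minimal usage sketch (the feature map shape is illustrative):
+
+        .. code-block:: python
+
+            import paddle
+            from paddlers.models.cd.models.layers import CBAM
+
+            cbam = CBAM(in_ch=64)
+            x = paddle.rand([1, 64, 32, 32])
+            y = cbam(x)  # Same shape as x, reweighted channel- and spatial-wise.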
+ """ + + def __init__(self, in_ch, ratio=8, kernel_size=7): + super().__init__() + self.ca = ChannelAttention(in_ch, ratio=ratio) + self.sa = SpatialAttention(kernel_size=kernel_size) + + def forward(self, x): + y = self.ca(x) * x + y = self.sa(y) * y + return y \ No newline at end of file diff --git a/paddlers/models/cd/models/layers/blocks.py b/paddlers/models/cd/models/layers/blocks.py new file mode 100644 index 0000000..bfd07ba --- /dev/null +++ b/paddlers/models/cd/models/layers/blocks.py @@ -0,0 +1,142 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.nn as nn + + +__all__ = [ + 'BasicConv', 'Conv1x1', 'Conv3x3', 'Conv7x7', + 'MaxPool2x2', 'MaxUnPool2x2', + 'ConvTransposed3x3', + 'Identity', + 'get_norm_layer', 'get_act_layer', + 'make_norm', 'make_act' +] + + +def get_norm_layer(): + # TODO: select appropriate norm layer. + return nn.BatchNorm2D + + +def get_act_layer(): + # TODO: select appropriate activation layer. + return nn.ReLU + + +def make_norm(*args, **kwargs): + norm_layer = get_norm_layer() + return norm_layer(*args, **kwargs) + + +def make_act(*args, **kwargs): + act_layer = get_act_layer() + return act_layer(*args, **kwargs) + + +class BasicConv(nn.Layer): + def __init__( + self, in_ch, out_ch, + kernel_size, pad_mode='constant', + bias='auto', norm=False, act=False, + **kwargs + ): + super().__init__() + seq = [] + if kernel_size >= 2: + seq.append(nn.Pad2D(kernel_size//2, mode=pad_mode)) + seq.append( + nn.Conv2D( + in_ch, out_ch, kernel_size, + stride=1, padding=0, + bias_attr=(False if norm else None) if bias=='auto' else bias, + **kwargs + ) + ) + if norm: + if norm is True: + norm = make_norm(out_ch) + seq.append(norm) + if act: + if act is True: + act = make_act() + seq.append(act) + self.seq = nn.Sequential(*seq) + + def forward(self, x): + return self.seq(x) + + +class Conv1x1(BasicConv): + def __init__(self, in_ch, out_ch, pad_mode='constant', bias='auto', norm=False, act=False, **kwargs): + super().__init__(in_ch, out_ch, 1, pad_mode=pad_mode, bias=bias, norm=norm, act=act, **kwargs) + + +class Conv3x3(BasicConv): + def __init__(self, in_ch, out_ch, pad_mode='constant', bias='auto', norm=False, act=False, **kwargs): + super().__init__(in_ch, out_ch, 3, pad_mode=pad_mode, bias=bias, norm=norm, act=act, **kwargs) + + +class Conv7x7(BasicConv): + def __init__(self, in_ch, out_ch, pad_mode='constant', bias='auto', norm=False, act=False, **kwargs): + super().__init__(in_ch, out_ch, 7, pad_mode=pad_mode, bias=bias, norm=norm, act=act, **kwargs) + + +class MaxPool2x2(nn.MaxPool2D): + def __init__(self, **kwargs): + super().__init__(kernel_size=2, stride=(2,2), padding=(0,0), **kwargs) + + +class MaxUnPool2x2(nn.MaxUnPool2D): + def __init__(self, **kwargs): + super().__init__(kernel_size=2, stride=(2,2), padding=(0,0), **kwargs) + + +class ConvTransposed3x3(nn.Layer): + def __init__( + self, in_ch, out_ch, + bias='auto', norm=False, act=False, + **kwargs + ): + super().__init__() + seq = 
[] + seq.append( + nn.Conv2DTranspose( + in_ch, out_ch, 3, + stride=2, padding=1, + bias_attr=(False if norm else None) if bias=='auto' else bias, + **kwargs + ) + ) + if norm: + if norm is True: + norm = make_norm(out_ch) + seq.append(norm) + if act: + if act is True: + act = make_act() + seq.append(act) + self.seq = nn.Sequential(*seq) + + def forward(self, x): + return self.seq(x) + + +class Identity(nn.Layer): + """A placeholder identity operator that accepts exactly one argument.""" + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, x): + return x \ No newline at end of file diff --git a/paddlers/models/cd/models/param_init.py b/paddlers/models/cd/models/param_init.py new file mode 100644 index 0000000..5e578f0 --- /dev/null +++ b/paddlers/models/cd/models/param_init.py @@ -0,0 +1,86 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.nn as nn +import paddle.nn.functional as F + + +def normal_init(param, *args, **kwargs): + """ + Initialize parameters with a normal distribution. + + Args: + param (Tensor): The tensor that needs to be initialized. + + Returns: + The initialized parameters. + """ + + return nn.initializer.Normal(*args, **kwargs)(param) + + +def kaiming_normal_init(param, *args, **kwargs): + """ + Initialize parameters with the Kaiming normal distribution. + + For more information about the Kaiming initialization method, please refer to + https://arxiv.org/abs/1502.01852 + + Args: + param (Tensor): The tensor that needs to be initialized. + + Returns: + The initialized parameters. + """ + + return nn.initializer.KaimingNormal(*args, **kwargs)(param) + + +def constant_init(param, *args, **kwargs): + """ + Initialize parameters with constants. + + Args: + param (Tensor): The tensor that needs to be initialized. + + Returns: + The initialized parameters. + """ + + return nn.initializer.Constant(*args, **kwargs)(param) + + +class KaimingInitMixin: + """ + A mix-in that provides the Kaiming initialization functionality. + + Examples: + + from paddlers.models.cd.models.param_init import KaimingInitMixin + + class CustomNet(nn.Layer, KaimingInitMixin): + def __init__(self, num_channels, num_classes): + super().__init__() + self.conv = nn.Conv2D(num_channels, num_classes, 3, 1, 0, bias_attr=False) + self.bn = nn.BatchNorm2D(num_classes) + self.init_weight() + """ + + def init_weight(self): + for layer in self.sublayers(): + if isinstance(layer, nn.Conv2D): + kaiming_normal_init(layer.weight) + elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)): + constant_init(layer.weight, value=1) + constant_init(layer.bias, value=0) \ No newline at end of file diff --git a/paddlers/models/cd/models/snunet.py b/paddlers/models/cd/models/snunet.py new file mode 100644 index 0000000..ecc86df --- /dev/null +++ b/paddlers/models/cd/models/snunet.py @@ -0,0 +1,155 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +from .layers import Conv1x1, MaxPool2x2, make_norm, ChannelAttention +from .param_init import KaimingInitMixin + + +class SNUNet(nn.Layer, KaimingInitMixin): + """ + The SNUNet implementation based on PaddlePaddle. + + The original article refers to + S. Fang, et al., "SNUNet-CD: A Densely Connected Siamese Network for Change Detection of VHR Images" + (https://ieeexplore.ieee.org/document/9355573). + + Note that bilinear interpolation is adopted as the upsampling method, which is different from the paper. + + Args: + in_channels (int): The number of bands of the input images. + num_classes (int): The number of target classes. + width (int, optional): The output channels of the first convolutional layer. Default: 32. + """ + + def __init__(self, in_channels, num_classes, width=32): + super().__init__() + + filters = (width, width*2, width*4, width*8, width*16) + + self.conv0_0 = ConvBlockNested(in_channels, filters[0], filters[0]) + self.conv1_0 = ConvBlockNested(filters[0], filters[1], filters[1]) + self.conv2_0 = ConvBlockNested(filters[1], filters[2], filters[2]) + self.conv3_0 = ConvBlockNested(filters[2], filters[3], filters[3]) + self.conv4_0 = ConvBlockNested(filters[3], filters[4], filters[4]) + self.down1 = MaxPool2x2() + self.down2 = MaxPool2x2() + self.down3 = MaxPool2x2() + self.down4 = MaxPool2x2() + self.up1_0 = Up(filters[1]) + self.up2_0 = Up(filters[2]) + self.up3_0 = Up(filters[3]) + self.up4_0 = Up(filters[4]) + + self.conv0_1 = ConvBlockNested(filters[0]*2+filters[1], filters[0], filters[0]) + self.conv1_1 = ConvBlockNested(filters[1]*2+filters[2], filters[1], filters[1]) + self.conv2_1 = ConvBlockNested(filters[2]*2+filters[3], filters[2], filters[2]) + self.conv3_1 = ConvBlockNested(filters[3]*2+filters[4], filters[3], filters[3]) + self.up1_1 = Up(filters[1]) + self.up2_1 = Up(filters[2]) + self.up3_1 = Up(filters[3]) + + self.conv0_2 = ConvBlockNested(filters[0]*3+filters[1], filters[0], filters[0]) + self.conv1_2 = ConvBlockNested(filters[1]*3+filters[2], filters[1], filters[1]) + self.conv2_2 = ConvBlockNested(filters[2]*3+filters[3], filters[2], filters[2]) + self.up1_2 = Up(filters[1]) + self.up2_2 = Up(filters[2]) + + self.conv0_3 = ConvBlockNested(filters[0]*4+filters[1], filters[0], filters[0]) + self.conv1_3 = ConvBlockNested(filters[1]*4+filters[2], filters[1], filters[1]) + self.up1_3 = Up(filters[1]) + + self.conv0_4 = ConvBlockNested(filters[0]*5+filters[1], filters[0], filters[0]) + + self.ca_intra = ChannelAttention(filters[0], ratio=4) + self.ca_inter = ChannelAttention(filters[0]*4, ratio=16) + + self.conv_out = Conv1x1(filters[0]*4, num_classes) + + self.init_weight() + + def forward(self, t1, t2): + x0_0_t1 = self.conv0_0(t1) + x1_0_t1 = self.conv1_0(self.down1(x0_0_t1)) + x2_0_t1 = self.conv2_0(self.down2(x1_0_t1)) + x3_0_t1 = self.conv3_0(self.down3(x2_0_t1)) + + x0_0_t2 = self.conv0_0(t2) + x1_0_t2 = 
self.conv1_0(self.down1(x0_0_t2)) + x2_0_t2 = self.conv2_0(self.down2(x1_0_t2)) + x3_0_t2 = self.conv3_0(self.down3(x2_0_t2)) + x4_0_t2 = self.conv4_0(self.down4(x3_0_t2)) + + x0_1 = self.conv0_1(paddle.concat([x0_0_t1, x0_0_t2, self.up1_0(x1_0_t2)], 1)) + x1_1 = self.conv1_1(paddle.concat([x1_0_t1, x1_0_t2, self.up2_0(x2_0_t2)], 1)) + x0_2 = self.conv0_2(paddle.concat([x0_0_t1, x0_0_t2, x0_1, self.up1_1(x1_1)], 1)) + + x2_1 = self.conv2_1(paddle.concat([x2_0_t1, x2_0_t2, self.up3_0(x3_0_t2)], 1)) + x1_2 = self.conv1_2(paddle.concat([x1_0_t1, x1_0_t2, x1_1, self.up2_1(x2_1)], 1)) + x0_3 = self.conv0_3(paddle.concat([x0_0_t1, x0_0_t2, x0_1, x0_2, self.up1_2(x1_2)], 1)) + + x3_1 = self.conv3_1(paddle.concat([x3_0_t1, x3_0_t2, self.up4_0(x4_0_t2)], 1)) + x2_2 = self.conv2_2(paddle.concat([x2_0_t1, x2_0_t2, x2_1, self.up3_1(x3_1)], 1)) + x1_3 = self.conv1_3(paddle.concat([x1_0_t1, x1_0_t2, x1_1, x1_2, self.up2_2(x2_2)], 1)) + x0_4 = self.conv0_4(paddle.concat([x0_0_t1, x0_0_t2, x0_1, x0_2, x0_3, self.up1_3(x1_3)], 1)) + + out = paddle.concat([x0_1, x0_2, x0_3, x0_4], 1) + + intra = paddle.sum(paddle.stack([x0_1, x0_2, x0_3, x0_4]), axis=0) + m_intra = self.ca_intra(intra) + out = self.ca_inter(out) * (out + paddle.tile(m_intra, (1,4,1,1))) + + pred = self.conv_out(out) + return pred, + + +class ConvBlockNested(nn.Layer): + def __init__(self, in_ch, out_ch, mid_ch): + super().__init__() + self.act = nn.ReLU() + self.conv1 = nn.Conv2D(in_ch, mid_ch, kernel_size=3, padding=1) + self.bn1 = make_norm(mid_ch) + self.conv2 = nn.Conv2D(mid_ch, out_ch, kernel_size=3, padding=1) + self.bn2 = make_norm(out_ch) + + def forward(self, x): + x = self.conv1(x) + identity = x + x = self.bn1(x) + x = self.act(x) + + x = self.conv2(x) + x = self.bn2(x) + output = self.act(x + identity) + return output + + +class Up(nn.Layer): + def __init__(self, in_ch, use_conv=False): + super().__init__() + if use_conv: + self.up = nn.Conv2DTranspose(in_ch, in_ch, 2, stride=2) + else: + self.up = nn.Upsample(scale_factor=2, + mode='bilinear', + align_corners=True) + + def forward(self, x): + x = self.up(x) + return x \ No newline at end of file diff --git a/paddlers/models/cd/models/stanet.py b/paddlers/models/cd/models/stanet.py new file mode 100644 index 0000000..967d361 --- /dev/null +++ b/paddlers/models/cd/models/stanet.py @@ -0,0 +1,298 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +from .backbones import resnet +from .layers import Conv1x1, Conv3x3, get_norm_layer, Identity +from .param_init import KaimingInitMixin + + +class STANet(nn.Layer): + """ + The STANet implementation based on PaddlePaddle. + + The original article refers to + H. Chen and Z. Shi, "A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection" + (https://www.mdpi.com/2072-4292/12/10/1662). 
+
+    Note that this implementation differs from the original work in two aspects:
+    1. We do not use multiple dilation rates in layer 4 of the ResNet backbone.
+    2. A classification head is used in place of the original metric learning-based head to stabilize the training process.
+
+    Args:
+        in_channels (int): The number of bands of the input images.
+        num_classes (int): The number of target classes.
+        att_type (str, optional): The attention module used in the model. Options are 'PAM' and 'BAM'. Default: 'BAM'.
+        ds_factor (int, optional): The downsampling factor of the attention modules. When `ds_factor` is set to values
+            greater than 1, the input features will first be processed by an average pooling layer with the kernel size of
+            `ds_factor`, before being used to calculate the attention scores. Default: 1.
+
+    Raises:
+        ValueError: When `att_type` has an illegal value (unsupported attention type).
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        num_classes,
+        att_type='BAM',
+        ds_factor=1
+    ):
+        super().__init__()
+
+        WIDTH = 64
+
+        self.extract = build_feat_extractor(in_ch=in_channels, width=WIDTH)
+        self.attend = build_sta_module(in_ch=WIDTH, att_type=att_type, ds=ds_factor)
+        self.conv_out = nn.Sequential(
+            Conv3x3(WIDTH, WIDTH, norm=True, act=True),
+            Conv3x3(WIDTH, num_classes)
+        )
+
+        self.init_weight()
+
+    def forward(self, t1, t2):
+        f1 = self.extract(t1)
+        f2 = self.extract(t2)
+
+        f1, f2 = self.attend(f1, f2)
+
+        y = paddle.abs(f1 - f2)
+        y = F.interpolate(y, size=paddle.shape(t1)[2:], mode='bilinear', align_corners=True)
+
+        pred = self.conv_out(y)
+        return pred,
+
+    def init_weight(self):
+        # Do nothing here as the encoder and decoder weights have already been initialized.
+        # Note however that currently self.attend and self.conv_out use the default initialization method.
+ pass + + +def build_feat_extractor(in_ch, width): + return nn.Sequential( + Backbone(in_ch, 'resnet18'), + Decoder(width) + ) + + +def build_sta_module(in_ch, att_type, ds): + if att_type == 'BAM': + return Attention(BAM(in_ch, ds)) + elif att_type == 'PAM': + return Attention(PAM(in_ch, ds)) + else: + raise ValueError + + +class Backbone(nn.Layer, KaimingInitMixin): + def __init__(self, in_ch, arch, pretrained=True, strides=(2,1,2,2,2)): + super().__init__() + + if arch == 'resnet18': + self.resnet = resnet.resnet18(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer()) + elif arch == 'resnet34': + self.resnet = resnet.resnet34(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer()) + elif arch == 'resnet50': + self.resnet = resnet.resnet50(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer()) + else: + raise ValueError + + self._trim_resnet() + + if in_ch != 3: + self.resnet.conv1 = nn.Conv2D( + in_ch, + 64, + kernel_size=7, + stride=strides[0], + padding=3, + bias_attr=False + ) + + if not pretrained: + self.init_weight() + + def forward(self, x): + x = self.resnet.conv1(x) + x = self.resnet.bn1(x) + x = self.resnet.relu(x) + x = self.resnet.maxpool(x) + + x1 = self.resnet.layer1(x) + x2 = self.resnet.layer2(x1) + x3 = self.resnet.layer3(x2) + x4 = self.resnet.layer4(x3) + + return x1, x2, x3, x4 + + def _trim_resnet(self): + self.resnet.avgpool = Identity() + self.resnet.fc = Identity() + + +class Decoder(nn.Layer, KaimingInitMixin): + def __init__(self, f_ch): + super().__init__() + self.dr1 = Conv1x1(64, 96, norm=True, act=True) + self.dr2 = Conv1x1(128, 96, norm=True, act=True) + self.dr3 = Conv1x1(256, 96, norm=True, act=True) + self.dr4 = Conv1x1(512, 96, norm=True, act=True) + self.conv_out = nn.Sequential( + Conv3x3(384, 256, norm=True, act=True), + nn.Dropout(0.5), + Conv1x1(256, f_ch, norm=True, act=True) + ) + + self.init_weight() + + def forward(self, feats): + f1 = self.dr1(feats[0]) + f2 = self.dr2(feats[1]) + f3 = self.dr3(feats[2]) + f4 = self.dr4(feats[3]) + + f2 = F.interpolate(f2, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True) + f3 = F.interpolate(f3, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True) + f4 = F.interpolate(f4, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True) + + x = paddle.concat([f1, f2, f3, f4], axis=1) + y = self.conv_out(x) + + return y + + +class BAM(nn.Layer): + def __init__(self, in_ch, ds): + super().__init__() + + self.ds = ds + self.pool = nn.AvgPool2D(self.ds) + + self.val_ch = in_ch + self.key_ch = in_ch // 8 + self.conv_q = Conv1x1(in_ch, self.key_ch) + self.conv_k = Conv1x1(in_ch, self.key_ch) + self.conv_v = Conv1x1(in_ch, self.val_ch) + + self.softmax = nn.Softmax(axis=-1) + + def forward(self, x): + x = x.flatten(-2) + x_rs = self.pool(x) + + b, c, h, w = paddle.shape(x_rs) + query = self.conv_q(x_rs).reshape((b,-1,h*w)).transpose((0,2,1)) + key = self.conv_k(x_rs).reshape((b,-1,h*w)) + energy = paddle.bmm(query, key) + energy = (self.key_ch**(-0.5)) * energy + + attention = self.softmax(energy) + + value = self.conv_v(x_rs).reshape((b,-1,w*h)) + + out = paddle.bmm(value, attention.transpose((0,2,1))) + out = out.reshape((b,c,h,w)) + + out = F.interpolate(out, scale_factor=self.ds) + out = out + x + return out.reshape(out.shape[:-1]+[out.shape[-1]//2, 2]) + + +class PAMBlock(nn.Layer): + def __init__(self, in_ch, scale=1, ds=1): + super().__init__() + + self.scale = scale + self.ds = ds + self.pool = nn.AvgPool2D(self.ds) + + self.val_ch = in_ch + 
self.key_ch = in_ch // 8 + self.conv_q = Conv1x1(in_ch, self.key_ch, norm=True) + self.conv_k = Conv1x1(in_ch, self.key_ch, norm=True) + self.conv_v = Conv1x1(in_ch, self.val_ch) + + def forward(self, x): + x_rs = self.pool(x) + + # Get query, key, and value. + query = self.conv_q(x_rs) + key = self.conv_k(x_rs) + value = self.conv_v(x_rs) + + # Split the whole image into subregions. + b, c, h, w = paddle.shape(x_rs) + query = self._split_subregions(query) + key = self._split_subregions(key) + value = self._split_subregions(value) + + # Perform subregion-wise attention. + out = self._attend(query, key, value) + + # Stack subregions to reconstruct the whole image. + out = self._recons_whole(out, b, c, h, w) + out = F.interpolate(out, scale_factor=self.ds) + return out + + def _attend(self, query, key, value): + energy = paddle.bmm(query.transpose((0,2,1)), key) # batch matrix multiplication + energy = (self.key_ch**(-0.5)) * energy + attention = F.softmax(energy, axis=-1) + out = paddle.bmm(value, attention.transpose((0,2,1))) + return out + + def _split_subregions(self, x): + b, c, h, w = paddle.shape(x) + assert h % self.scale == 0 and w % self.scale == 0 + x = x.reshape((b, c, self.scale, h//self.scale, self.scale, w//self.scale)) + x = x.transpose((0,2,4,1,3,5)).reshape((b*self.scale*self.scale, c, -1)) + return x + + def _recons_whole(self, x, b, c, h, w): + x = x.reshape((b, self.scale, self.scale, c, h//self.scale, w//self.scale)) + x = x.transpose((0,3,1,4,2,5)).reshape((b, c, h, w)) + return x + + +class PAM(nn.Layer): + def __init__(self, in_ch, ds, scales=(1,2,4,8)): + super().__init__() + + self.stages = nn.LayerList([ + PAMBlock(in_ch, scale=s, ds=ds) + for s in scales + ]) + self.conv_out = Conv1x1(in_ch*len(scales), in_ch, bias=False) + + def forward(self, x): + x = x.flatten(-2) + res = [stage(x) for stage in self.stages] + out = self.conv_out(paddle.concat(res, axis=1)) + return out.reshape(out.shape[:-1]+[out.shape[-1]//2, 2]) + + +class Attention(nn.Layer): + def __init__(self, att): + super().__init__() + self.att = att + + def forward(self, x1, x2): + x = paddle.stack([x1, x2], axis=-1) + y = self.att(x) + return y[...,0], y[...,1] \ No newline at end of file diff --git a/paddlers/models/cd/models/unet_ef.py b/paddlers/models/cd/models/unet_ef.py new file mode 100644 index 0000000..5f099f5 --- /dev/null +++ b/paddlers/models/cd/models/unet_ef.py @@ -0,0 +1,201 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +from .layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity +from .param_init import normal_init, constant_init + + +class UNetEarlyFusion(nn.Layer): + """ + The FC-EF implementation based on PaddlePaddle. + + The original article refers to + Caye Daudt, R., et al. "Fully convolutional siamese networks for change detection" + (https://arxiv.org/abs/1810.08462). 
+ + Args: + in_channels (int): The number of bands of the input images. + num_classes (int): The number of target classes. + use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is trained + on a relatively small dataset, the dropout layers help prevent overfitting. Default: False. + """ + + def __init__( + self, + in_channels, + num_classes, + use_dropout=False + ): + super().__init__() + + C1, C2, C3, C4, C5 = 16, 32, 64, 128, 256 + + self.use_dropout = use_dropout + + self.conv11 = Conv3x3(in_channels, C1, norm=True, act=True) + self.do11 = self._make_dropout() + self.conv12 = Conv3x3(C1, C1, norm=True, act=True) + self.do12 = self._make_dropout() + self.pool1 = MaxPool2x2() + + self.conv21 = Conv3x3(C1, C2, norm=True, act=True) + self.do21 = self._make_dropout() + self.conv22 = Conv3x3(C2, C2, norm=True, act=True) + self.do22 = self._make_dropout() + self.pool2 = MaxPool2x2() + + self.conv31 = Conv3x3(C2, C3, norm=True, act=True) + self.do31 = self._make_dropout() + self.conv32 = Conv3x3(C3, C3, norm=True, act=True) + self.do32 = self._make_dropout() + self.conv33 = Conv3x3(C3, C3, norm=True, act=True) + self.do33 = self._make_dropout() + self.pool3 = MaxPool2x2() + + self.conv41 = Conv3x3(C3, C4, norm=True, act=True) + self.do41 = self._make_dropout() + self.conv42 = Conv3x3(C4, C4, norm=True, act=True) + self.do42 = self._make_dropout() + self.conv43 = Conv3x3(C4, C4, norm=True, act=True) + self.do43 = self._make_dropout() + self.pool4 = MaxPool2x2() + + self.upconv4 = ConvTransposed3x3(C4, C4, output_padding=1) + + self.conv43d = Conv3x3(C5, C4, norm=True, act=True) + self.do43d = self._make_dropout() + self.conv42d = Conv3x3(C4, C4, norm=True, act=True) + self.do42d = self._make_dropout() + self.conv41d = Conv3x3(C4, C3, norm=True, act=True) + self.do41d = self._make_dropout() + + self.upconv3 = ConvTransposed3x3(C3, C3, output_padding=1) + + self.conv33d = Conv3x3(C4, C3, norm=True, act=True) + self.do33d = self._make_dropout() + self.conv32d = Conv3x3(C3, C3, norm=True, act=True) + self.do32d = self._make_dropout() + self.conv31d = Conv3x3(C3, C2, norm=True, act=True) + self.do31d = self._make_dropout() + + self.upconv2 = ConvTransposed3x3(C2, C2, output_padding=1) + + self.conv22d = Conv3x3(C3, C2, norm=True, act=True) + self.do22d = self._make_dropout() + self.conv21d = Conv3x3(C2, C1, norm=True, act=True) + self.do21d = self._make_dropout() + + self.upconv1 = ConvTransposed3x3(C1, C1, output_padding=1) + + self.conv12d = Conv3x3(C2, C1, norm=True, act=True) + self.do12d = self._make_dropout() + self.conv11d = Conv3x3(C1, num_classes) + + self.init_weight() + + def forward(self, t1, t2): + x = paddle.concat([t1, t2], axis=1) + + # Stage 1 + x11 = self.do11(self.conv11(x)) + x12 = self.do12(self.conv12(x11)) + x1p = self.pool1(x12) + + # Stage 2 + x21 = self.do21(self.conv21(x1p)) + x22 = self.do22(self.conv22(x21)) + x2p = self.pool2(x22) + + # Stage 3 + x31 = self.do31(self.conv31(x2p)) + x32 = self.do32(self.conv32(x31)) + x33 = self.do33(self.conv33(x32)) + x3p = self.pool3(x33) + + # Stage 4 + x41 = self.do41(self.conv41(x3p)) + x42 = self.do42(self.conv42(x41)) + x43 = self.do43(self.conv43(x42)) + x4p = self.pool4(x43) + + # Stage 4d + x4d = self.upconv4(x4p) + pad4 = ( + 0, + paddle.shape(x43)[3]-paddle.shape(x4d)[3], + 0, + paddle.shape(x43)[2]-paddle.shape(x4d)[2] + ) + x4d = paddle.concat([F.pad(x4d, pad=pad4, mode='replicate'), x43], 1) + x43d = self.do43d(self.conv43d(x4d)) + x42d = self.do42d(self.conv42d(x43d)) + 
x41d = self.do41d(self.conv41d(x42d)) + + # Stage 3d + x3d = self.upconv3(x41d) + pad3 = ( + 0, + paddle.shape(x33)[3]-paddle.shape(x3d)[3], + 0, + paddle.shape(x33)[2]-paddle.shape(x3d)[2] + ) + x3d = paddle.concat([F.pad(x3d, pad=pad3, mode='replicate'), x33], 1) + x33d = self.do33d(self.conv33d(x3d)) + x32d = self.do32d(self.conv32d(x33d)) + x31d = self.do31d(self.conv31d(x32d)) + + # Stage 2d + x2d = self.upconv2(x31d) + pad2 = ( + 0, + paddle.shape(x22)[3]-paddle.shape(x2d)[3], + 0, + paddle.shape(x22)[2]-paddle.shape(x2d)[2] + ) + x2d = paddle.concat([F.pad(x2d, pad=pad2, mode='replicate'), x22], 1) + x22d = self.do22d(self.conv22d(x2d)) + x21d = self.do21d(self.conv21d(x22d)) + + # Stage 1d + x1d = self.upconv1(x21d) + pad1 = ( + 0, + paddle.shape(x12)[3]-paddle.shape(x1d)[3], + 0, + paddle.shape(x12)[2]-paddle.shape(x1d)[2] + ) + x1d = paddle.concat([F.pad(x1d, pad=pad1, mode='replicate'), x12], 1) + x12d = self.do12d(self.conv12d(x1d)) + x11d = self.conv11d(x12d) + + return x11d, + + def init_weight(self): + for sublayer in self.sublayers(): + if isinstance(sublayer, nn.Conv2D): + normal_init(sublayer.weight, std=0.001) + elif isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)): + constant_init(sublayer.weight, value=1.0) + constant_init(sublayer.bias, value=0.0) + + def _make_dropout(self): + if self.use_dropout: + return nn.Dropout2D(p=0.2) + else: + return Identity() \ No newline at end of file diff --git a/paddlers/models/cd/models/unet_siamconc.py b/paddlers/models/cd/models/unet_siamconc.py new file mode 100644 index 0000000..df979b0 --- /dev/null +++ b/paddlers/models/cd/models/unet_siamconc.py @@ -0,0 +1,224 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +from .layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity +from .param_init import normal_init, constant_init + + +class UNetSiamConc(nn.Layer): + """ + The FC-Siam-conc implementation based on PaddlePaddle. + + The original article refers to + Caye Daudt, R., et al. "Fully convolutional siamese networks for change detection" + (https://arxiv.org/abs/1810.08462). + + Args: + in_channels (int): The number of bands of the input images. + num_classes (int): The number of target classes. + use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is trained + on a relatively small dataset, the dropout layers help prevent overfitting. Default: False. 
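+
+    FC-Siam-conc encodes both temporal images with a shared-weight encoder
+    and, at each decoder stage, concatenates the skip connections of the two
+    streams with the upsampled features, which is why the first convolution
+    of each decoder stage takes one extra skip tensor's worth of input
+    channels compared with FC-EF.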
+ """ + + def __init__( + self, + in_channels, + num_classes, + use_dropout=False + ): + super().__init__() + + C1, C2, C3, C4, C5 = 16, 32, 64, 128, 256 + + self.use_dropout = use_dropout + + self.conv11 = Conv3x3(in_channels, C1, norm=True, act=True) + self.do11 = self._make_dropout() + self.conv12 = Conv3x3(C1, C1, norm=True, act=True) + self.do12 = self._make_dropout() + self.pool1 = MaxPool2x2() + + self.conv21 = Conv3x3(C1, C2, norm=True, act=True) + self.do21 = self._make_dropout() + self.conv22 = Conv3x3(C2, C2, norm=True, act=True) + self.do22 = self._make_dropout() + self.pool2 = MaxPool2x2() + + self.conv31 = Conv3x3(C2, C3, norm=True, act=True) + self.do31 = self._make_dropout() + self.conv32 = Conv3x3(C3, C3, norm=True, act=True) + self.do32 = self._make_dropout() + self.conv33 = Conv3x3(C3, C3, norm=True, act=True) + self.do33 = self._make_dropout() + self.pool3 = MaxPool2x2() + + self.conv41 = Conv3x3(C3, C4, norm=True, act=True) + self.do41 = self._make_dropout() + self.conv42 = Conv3x3(C4, C4, norm=True, act=True) + self.do42 = self._make_dropout() + self.conv43 = Conv3x3(C4, C4, norm=True, act=True) + self.do43 = self._make_dropout() + self.pool4 = MaxPool2x2() + + self.upconv4 = ConvTransposed3x3(C4, C4, output_padding=1) + + self.conv43d = Conv3x3(C5+C4, C4, norm=True, act=True) + self.do43d = self._make_dropout() + self.conv42d = Conv3x3(C4, C4, norm=True, act=True) + self.do42d = self._make_dropout() + self.conv41d = Conv3x3(C4, C3, norm=True, act=True) + self.do41d = self._make_dropout() + + self.upconv3 = ConvTransposed3x3(C3, C3, output_padding=1) + + self.conv33d = Conv3x3(C4+C3, C3, norm=True, act=True) + self.do33d = self._make_dropout() + self.conv32d = Conv3x3(C3, C3, norm=True, act=True) + self.do32d = self._make_dropout() + self.conv31d = Conv3x3(C3, C2, norm=True, act=True) + self.do31d = self._make_dropout() + + self.upconv2 = ConvTransposed3x3(C2, C2, output_padding=1) + + self.conv22d = Conv3x3(C3+C2, C2, norm=True, act=True) + self.do22d = self._make_dropout() + self.conv21d = Conv3x3(C2, C1, norm=True, act=True) + self.do21d = self._make_dropout() + + self.upconv1 = ConvTransposed3x3(C1, C1, output_padding=1) + + self.conv12d = Conv3x3(C2+C1, C1, norm=True, act=True) + self.do12d = self._make_dropout() + self.conv11d = Conv3x3(C1, num_classes) + + self.init_weight() + + def forward(self, t1, t2): + # Encode t1 + # Stage 1 + x11 = self.do11(self.conv11(t1)) + x12_1 = self.do12(self.conv12(x11)) + x1p = self.pool1(x12_1) + + # Stage 2 + x21 = self.do21(self.conv21(x1p)) + x22_1 = self.do22(self.conv22(x21)) + x2p = self.pool2(x22_1) + + # Stage 3 + x31 = self.do31(self.conv31(x2p)) + x32 = self.do32(self.conv32(x31)) + x33_1 = self.do33(self.conv33(x32)) + x3p = self.pool3(x33_1) + + # Stage 4 + x41 = self.do41(self.conv41(x3p)) + x42 = self.do42(self.conv42(x41)) + x43_1 = self.do43(self.conv43(x42)) + x4p = self.pool4(x43_1) + + # Encode t2 + # Stage 1 + x11 = self.do11(self.conv11(t2)) + x12_2 = self.do12(self.conv12(x11)) + x1p = self.pool1(x12_2) + + # Stage 2 + x21 = self.do21(self.conv21(x1p)) + x22_2 = self.do22(self.conv22(x21)) + x2p = self.pool2(x22_2) + + # Stage 3 + x31 = self.do31(self.conv31(x2p)) + x32 = self.do32(self.conv32(x31)) + x33_2 = self.do33(self.conv33(x32)) + x3p = self.pool3(x33_2) + + # Stage 4 + x41 = self.do41(self.conv41(x3p)) + x42 = self.do42(self.conv42(x41)) + x43_2 = self.do43(self.conv43(x42)) + x4p = self.pool4(x43_2) + + # Decode + # Stage 4d + x4d = self.upconv4(x4p) + pad4 = ( + 0, + 
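+            # Only the right and bottom edges are padded, by the width/height
+            # difference between the skip tensor and the upsampled map
+            # (F.pad expects (left, right, top, bottom) for NCHW inputs).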
paddle.shape(x43_1)[3]-paddle.shape(x4d)[3], + 0, + paddle.shape(x43_1)[2]-paddle.shape(x4d)[2] + ) + x4d = paddle.concat([F.pad(x4d, pad=pad4, mode='replicate'), x43_1, x43_2], 1) + x43d = self.do43d(self.conv43d(x4d)) + x42d = self.do42d(self.conv42d(x43d)) + x41d = self.do41d(self.conv41d(x42d)) + + # Stage 3d + x3d = self.upconv3(x41d) + pad3 = ( + 0, + paddle.shape(x33_1)[3]-paddle.shape(x3d)[3], + 0, + paddle.shape(x33_1)[2]-paddle.shape(x3d)[2] + ) + x3d = paddle.concat([F.pad(x3d, pad=pad3, mode='replicate'), x33_1, x33_2], 1) + x33d = self.do33d(self.conv33d(x3d)) + x32d = self.do32d(self.conv32d(x33d)) + x31d = self.do31d(self.conv31d(x32d)) + + # Stage 2d + x2d = self.upconv2(x31d) + pad2 = ( + 0, + paddle.shape(x22_1)[3]-paddle.shape(x2d)[3], + 0, + paddle.shape(x22_1)[2]-paddle.shape(x2d)[2] + ) + x2d = paddle.concat([F.pad(x2d, pad=pad2, mode='replicate'), x22_1, x22_2], 1) + x22d = self.do22d(self.conv22d(x2d)) + x21d = self.do21d(self.conv21d(x22d)) + + # Stage 1d + x1d = self.upconv1(x21d) + pad1 = ( + 0, + paddle.shape(x12_1)[3]-paddle.shape(x1d)[3], + 0, + paddle.shape(x12_1)[2]-paddle.shape(x1d)[2] + ) + x1d = paddle.concat([F.pad(x1d, pad=pad1, mode='replicate'), x12_1, x12_2], 1) + x12d = self.do12d(self.conv12d(x1d)) + x11d = self.conv11d(x12d) + + return x11d, + + def init_weight(self): + for sublayer in self.sublayers(): + if isinstance(sublayer, nn.Conv2D): + normal_init(sublayer.weight, std=0.001) + elif isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)): + constant_init(sublayer.weight, value=1.0) + constant_init(sublayer.bias, value=0.0) + + def _make_dropout(self): + if self.use_dropout: + return nn.Dropout2D(p=0.2) + else: + return Identity() \ No newline at end of file diff --git a/paddlers/models/cd/models/unet_siamdiff.py b/paddlers/models/cd/models/unet_siamdiff.py new file mode 100644 index 0000000..cca4d7c --- /dev/null +++ b/paddlers/models/cd/models/unet_siamdiff.py @@ -0,0 +1,224 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +from .layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity +from .param_init import normal_init, constant_init + + +class UNetSiamDiff(nn.Layer): + """ + The FC-Siam-diff implementation based on PaddlePaddle. + + The original article refers to + Caye Daudt, R., et al. "Fully convolutional siamese networks for change detection" + (https://arxiv.org/abs/1810.08462). + + Args: + in_channels (int): The number of bands of the input images. + num_classes (int): The number of target classes. + use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is trained + on a relatively small dataset, the dropout layers help prevent overfitting. Default: False. 
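+
+    FC-Siam-diff also encodes both temporal images with a shared-weight
+    encoder, but its decoder concatenates the absolute difference of the two
+    streams' skip features rather than both feature maps, so the decoder
+    channel counts are identical to those of FC-EF.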
+ """ + + def __init__( + self, + in_channels, + num_classes, + use_dropout=False + ): + super().__init__() + + C1, C2, C3, C4, C5 = 16, 32, 64, 128, 256 + + self.use_dropout = use_dropout + + self.conv11 = Conv3x3(in_channels, C1, norm=True, act=True) + self.do11 = self._make_dropout() + self.conv12 = Conv3x3(C1, C1, norm=True, act=True) + self.do12 = self._make_dropout() + self.pool1 = MaxPool2x2() + + self.conv21 = Conv3x3(C1, C2, norm=True, act=True) + self.do21 = self._make_dropout() + self.conv22 = Conv3x3(C2, C2, norm=True, act=True) + self.do22 = self._make_dropout() + self.pool2 = MaxPool2x2() + + self.conv31 = Conv3x3(C2, C3, norm=True, act=True) + self.do31 = self._make_dropout() + self.conv32 = Conv3x3(C3, C3, norm=True, act=True) + self.do32 = self._make_dropout() + self.conv33 = Conv3x3(C3, C3, norm=True, act=True) + self.do33 = self._make_dropout() + self.pool3 = MaxPool2x2() + + self.conv41 = Conv3x3(C3, C4, norm=True, act=True) + self.do41 = self._make_dropout() + self.conv42 = Conv3x3(C4, C4, norm=True, act=True) + self.do42 = self._make_dropout() + self.conv43 = Conv3x3(C4, C4, norm=True, act=True) + self.do43 = self._make_dropout() + self.pool4 = MaxPool2x2() + + self.upconv4 = ConvTransposed3x3(C4, C4, output_padding=1) + + self.conv43d = Conv3x3(C5, C4, norm=True, act=True) + self.do43d = self._make_dropout() + self.conv42d = Conv3x3(C4, C4, norm=True, act=True) + self.do42d = self._make_dropout() + self.conv41d = Conv3x3(C4, C3, norm=True, act=True) + self.do41d = self._make_dropout() + + self.upconv3 = ConvTransposed3x3(C3, C3, output_padding=1) + + self.conv33d = Conv3x3(C4, C3, norm=True, act=True) + self.do33d = self._make_dropout() + self.conv32d = Conv3x3(C3, C3, norm=True, act=True) + self.do32d = self._make_dropout() + self.conv31d = Conv3x3(C3, C2, norm=True, act=True) + self.do31d = self._make_dropout() + + self.upconv2 = ConvTransposed3x3(C2, C2, output_padding=1) + + self.conv22d = Conv3x3(C3, C2, norm=True, act=True) + self.do22d = self._make_dropout() + self.conv21d = Conv3x3(C2, C1, norm=True, act=True) + self.do21d = self._make_dropout() + + self.upconv1 = ConvTransposed3x3(C1, C1, output_padding=1) + + self.conv12d = Conv3x3(C2, C1, norm=True, act=True) + self.do12d = self._make_dropout() + self.conv11d = Conv3x3(C1, num_classes) + + self.init_weight() + + def forward(self, t1, t2): + # Encode t1 + # Stage 1 + x11 = self.do11(self.conv11(t1)) + x12_1 = self.do12(self.conv12(x11)) + x1p = self.pool1(x12_1) + + # Stage 2 + x21 = self.do21(self.conv21(x1p)) + x22_1 = self.do22(self.conv22(x21)) + x2p = self.pool2(x22_1) + + # Stage 3 + x31 = self.do31(self.conv31(x2p)) + x32 = self.do32(self.conv32(x31)) + x33_1 = self.do33(self.conv33(x32)) + x3p = self.pool3(x33_1) + + # Stage 4 + x41 = self.do41(self.conv41(x3p)) + x42 = self.do42(self.conv42(x41)) + x43_1 = self.do43(self.conv43(x42)) + x4p = self.pool4(x43_1) + + # Encode t2 + # Stage 1 + x11 = self.do11(self.conv11(t2)) + x12_2 = self.do12(self.conv12(x11)) + x1p = self.pool1(x12_2) + + # Stage 2 + x21 = self.do21(self.conv21(x1p)) + x22_2 = self.do22(self.conv22(x21)) + x2p = self.pool2(x22_2) + + # Stage 3 + x31 = self.do31(self.conv31(x2p)) + x32 = self.do32(self.conv32(x31)) + x33_2 = self.do33(self.conv33(x32)) + x3p = self.pool3(x33_2) + + # Stage 4 + x41 = self.do41(self.conv41(x3p)) + x42 = self.do42(self.conv42(x41)) + x43_2 = self.do43(self.conv43(x42)) + x4p = self.pool4(x43_2) + + # Decode + # Stage 4d + x4d = self.upconv4(x4p) + pad4 = ( + 0, + 
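+            # Same replicate-padding trick as in the other variants: grow the
+            # upsampled map to the spatial size of the skip features.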
paddle.shape(x43_1)[3]-paddle.shape(x4d)[3], + 0, + paddle.shape(x43_1)[2]-paddle.shape(x4d)[2] + ) + x4d = paddle.concat([F.pad(x4d, pad=pad4, mode='replicate'), paddle.abs(x43_1-x43_2)], 1) + x43d = self.do43d(self.conv43d(x4d)) + x42d = self.do42d(self.conv42d(x43d)) + x41d = self.do41d(self.conv41d(x42d)) + + # Stage 3d + x3d = self.upconv3(x41d) + pad3 = ( + 0, + paddle.shape(x33_1)[3]-paddle.shape(x3d)[3], + 0, + paddle.shape(x33_1)[2]-paddle.shape(x3d)[2] + ) + x3d = paddle.concat([F.pad(x3d, pad=pad3, mode='replicate'), paddle.abs(x33_1-x33_2)], 1) + x33d = self.do33d(self.conv33d(x3d)) + x32d = self.do32d(self.conv32d(x33d)) + x31d = self.do31d(self.conv31d(x32d)) + + # Stage 2d + x2d = self.upconv2(x31d) + pad2 = ( + 0, + paddle.shape(x22_1)[3]-paddle.shape(x2d)[3], + 0, + paddle.shape(x22_1)[2]-paddle.shape(x2d)[2] + ) + x2d = paddle.concat([F.pad(x2d, pad=pad2, mode='replicate'), paddle.abs(x22_1-x22_2)], 1) + x22d = self.do22d(self.conv22d(x2d)) + x21d = self.do21d(self.conv21d(x22d)) + + # Stage 1d + x1d = self.upconv1(x21d) + pad1 = ( + 0, + paddle.shape(x12_1)[3]-paddle.shape(x1d)[3], + 0, + paddle.shape(x12_1)[2]-paddle.shape(x1d)[2] + ) + x1d = paddle.concat([F.pad(x1d, pad=pad1, mode='replicate'), paddle.abs(x12_1-x12_2)], 1) + x12d = self.do12d(self.conv12d(x1d)) + x11d = self.conv11d(x12d) + + return x11d, + + def init_weight(self): + for sublayer in self.sublayers(): + if isinstance(sublayer, nn.Conv2D): + normal_init(sublayer.weight, std=0.001) + elif isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)): + constant_init(sublayer.weight, value=1.0) + constant_init(sublayer.bias, value=0.0) + + def _make_dropout(self): + if self.use_dropout: + return nn.Dropout2D(p=0.2) + else: + return Identity() \ No newline at end of file diff --git a/paddlers/tasks/changedetector.py b/paddlers/tasks/changedetector.py index 169dc15..d01390c 100644 --- a/paddlers/tasks/changedetector.py +++ b/paddlers/tasks/changedetector.py @@ -31,7 +31,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict from paddlers.transforms import ImgDecoder, Resize import paddlers.models.cd as cd -__all__ = ["CDNet"] +__all__ = ["CDNet", "UNetEarlyFusion", "UNetSiamConc", "UNetSiamDiff", "STANet", "BIT", "SNUNet", "DSIFN", "DSAMNet"] class BaseChangeDetector(BaseModel): @@ -663,3 +663,190 @@ class CDNet(BaseChangeDetector): num_classes=num_classes, use_mixed_loss=use_mixed_loss, **params) + + +class UNetEarlyFusion(BaseChangeDetector): + def __init__(self, + num_classes=2, + use_mixed_loss=False, + in_channels=6, + use_dropout=False, + **params): + params.update({ + 'in_channels': in_channels, + 'use_dropout': use_dropout + }) + super(UNetEarlyFusion, self).__init__( + model_name='UNetEarlyFusion', + num_classes=num_classes, + use_mixed_loss=use_mixed_loss, + **params) + + +class UNetSiamConc(BaseChangeDetector): + def __init__(self, + num_classes=2, + use_mixed_loss=False, + in_channels=3, + use_dropout=False, + **params): + params.update({ + 'in_channels': in_channels, + 'use_dropout': use_dropout + }) + super(UNetSiamConc, self).__init__( + model_name='UNetSiamConc', + num_classes=num_classes, + use_mixed_loss=use_mixed_loss, + **params) + + +class UNetSiamDiff(BaseChangeDetector): + def __init__(self, + num_classes=2, + use_mixed_loss=False, + in_channels=3, + use_dropout=False, + **params): + params.update({ + 'in_channels': in_channels, + 'use_dropout': use_dropout + }) + super(UNetSiamDiff, self).__init__( + model_name='UNetSiamDiff', + num_classes=num_classes, + 
use_mixed_loss=use_mixed_loss, + **params) + + +class STANet(BaseChangeDetector): + def __init__(self, + num_classes=2, + use_mixed_loss=False, + in_channels=3, + att_type='BAM', + ds_factor=1, + **params): + params.update({ + 'in_channels': in_channels, + 'att_type': att_type, + 'ds_factor': ds_factor + }) + super(STANet, self).__init__( + model_name='STANet', + num_classes=num_classes, + use_mixed_loss=use_mixed_loss, + **params) + + +class BIT(BaseChangeDetector): + def __init__(self, + num_classes=2, + use_mixed_loss=False, + in_channels=3, + backbone='resnet18', + n_stages=4, + use_tokenizer=True, + token_len=4, + pool_mode='max', + pool_size=2, + enc_with_pos=True, + enc_depth=1, + enc_head_dim=64, + dec_depth=8, + dec_head_dim=8, + **params): + params.update({ + 'in_channels': in_channels, + 'backbone': backbone, + 'n_stages': n_stages, + 'use_tokenizer': use_tokenizer, + 'token_len': token_len, + 'pool_mode': pool_mode, + 'pool_size': pool_size, + 'enc_with_pos': enc_with_pos, + 'enc_depth': enc_depth, + 'enc_head_dim': enc_head_dim, + 'dec_depth': dec_depth, + 'dec_head_dim': dec_head_dim + }) + super(BIT, self).__init__( + model_name='BIT', + num_classes=num_classes, + use_mixed_loss=use_mixed_loss, + **params) + + +class SNUNet(BaseChangeDetector): + def __init__(self, + num_classes=2, + use_mixed_loss=False, + in_channels=3, + width=32, + **params): + params.update({ + 'in_channels': in_channels, + 'width': width + }) + super(SNUNet, self).__init__( + model_name='SNUNet', + num_classes=num_classes, + use_mixed_loss=use_mixed_loss, + **params) + + +class DSIFN(BaseChangeDetector): + def __init__(self, + num_classes=2, + use_mixed_loss=None, + use_dropout=False, + **params): + params.update({ + 'use_dropout': use_dropout + }) + super(DSIFN, self).__init__( + model_name='DSIFN', + num_classes=num_classes, + use_mixed_loss=use_mixed_loss, + **params) + # HACK: currently the only legal value of `use_mixed_loss` is None, in which case the loss specifications are + # constructed automatically. + assert use_mixed_loss is None + if use_mixed_loss is None: + self.losses = { + # XXX: make sure the shallow copy works correctly here. + 'types': [paddleseg.models.CrossEntropyLoss()]*5, + 'coef': [1.0]*5 + } + + +class DSAMNet(BaseChangeDetector): + def __init__(self, + num_classes=2, + use_mixed_loss=None, + in_channels=3, + ca_ratio=8, + sa_kernel=7, + **params): + params.update({ + 'in_channels': in_channels, + 'ca_ratio': ca_ratio, + 'sa_kernel': sa_kernel + }) + super(DSAMNet, self).__init__( + model_name='DSAMNet', + num_classes=num_classes, + use_mixed_loss=use_mixed_loss, + **params) + # HACK: currently the only legal value of `use_mixed_loss` is None, in which case the loss specifications are + # constructed automatically. + assert use_mixed_loss is None + if use_mixed_loss is None: + self.losses = { + 'types': [ + paddleseg.models.CrossEntropyLoss(), + paddleseg.models.DiceLoss(), + paddleseg.models.DiceLoss() + ], + 'coef': [1.0, 0.05, 0.05] + } \ No newline at end of file
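
For reviewers, a minimal usage sketch of the trainers registered by this patch. It assumes that `paddlers.tasks` re-exports the classes defined in `changedetector.py`, as it does for the existing CDNet trainer; that import path is not shown in this diff.

    import paddlers as pdrs  # assumed package entry point

    # FC-EF consumes the two images concatenated along channels, so
    # in_channels defaults to 6 (two 3-band images).
    trainer = pdrs.tasks.UNetEarlyFusion(num_classes=2, in_channels=6, use_dropout=True)

    # The Siamese variants encode each image separately; in_channels is per image.
    trainer = pdrs.tasks.UNetSiamDiff(num_classes=2, in_channels=3)

    # DSIFN and DSAMNet build their multi-output loss specifications
    # automatically, so use_mixed_loss must be left as None.
    trainer = pdrs.tasks.DSAMNet(num_classes=2, in_channels=3, ca_ratio=8, sa_kernel=7)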