Merge pull request #12 from Bobholamovic/dev_gdf

[Feature] Add eight change detection models
Liu Yi authored 3 years ago, committed by GitHub
commit d5a7ddaf6f
Changed files (number of changed lines):
  1. paddlers/models/cd/models/__init__.py (8)
  2. paddlers/models/cd/models/backbones/__init__.py (13)
  3. paddlers/models/cd/models/backbones/resnet.py (358)
  4. paddlers/models/cd/models/bit.py (395)
  5. paddlers/models/cd/models/dsamnet.py (96)
  6. paddlers/models/cd/models/dsifn.py (209)
  7. paddlers/models/cd/models/layers/__init__.py (16)
  8. paddlers/models/cd/models/layers/attention.py (96)
  9. paddlers/models/cd/models/layers/blocks.py (142)
  10. paddlers/models/cd/models/param_init.py (86)
  11. paddlers/models/cd/models/snunet.py (155)
  12. paddlers/models/cd/models/stanet.py (298)
  13. paddlers/models/cd/models/unet_ef.py (201)
  14. paddlers/models/cd/models/unet_siamconc.py (224)
  15. paddlers/models/cd/models/unet_siamdiff.py (224)
  16. paddlers/tasks/changedetector.py (189)

paddlers/models/cd/models/__init__.py
@@ -12,4 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .bit import BIT
from .cdnet import CDNet
from .dsifn import DSIFN
from .stanet import STANet
from .snunet import SNUNet
from .dsamnet import DSAMNet
from .unet_ef import UNetEarlyFusion
from .unet_siamconc import UNetSiamConc
from .unet_siamdiff import UNetSiamDiff

paddlers/models/cd/models/backbones/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

paddlers/models/cd/models/backbones/resnet.py
@@ -0,0 +1,358 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/PaddlePaddle/Paddle/blob/release/2.2/python/paddle/vision/models/resnet.py
## Original head information
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
from paddle.utils.download import get_weights_path_from_url
__all__ = []
model_urls = {
'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams',
'cf548f46534aa3560945be4b95cd11c4'),
'resnet34': ('https://paddle-hapi.bj.bcebos.com/models/resnet34.pdparams',
'8d2275cf8706028345f78ac0e1d31969'),
'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams',
'ca6f485ee1ab0492d38f323885b0ad80'),
'resnet101': ('https://paddle-hapi.bj.bcebos.com/models/resnet101.pdparams',
'02f35f034ca3858e1e54d4036443c92d'),
'resnet152': ('https://paddle-hapi.bj.bcebos.com/models/resnet152.pdparams',
'7ad16a2f1e7333859ff986138630fd7a'),
}
class BasicBlock(nn.Layer):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
groups=1,
base_width=64,
dilation=1,
norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2D
if dilation > 1:
raise NotImplementedError(
"Dilation > 1 not supported in BasicBlock")
self.conv1 = nn.Conv2D(
inplanes, planes, 3, padding=1, stride=stride, bias_attr=False)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class BottleneckBlock(nn.Layer):
expansion = 4
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
groups=1,
base_width=64,
dilation=1,
norm_layer=None):
super(BottleneckBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2D
width = int(planes * (base_width / 64.)) * groups
self.conv1 = nn.Conv2D(inplanes, width, 1, bias_attr=False)
self.bn1 = norm_layer(width)
self.conv2 = nn.Conv2D(
width,
width,
3,
padding=dilation,
stride=stride,
groups=groups,
dilation=dilation,
bias_attr=False)
self.bn2 = norm_layer(width)
self.conv3 = nn.Conv2D(
width, planes * self.expansion, 1, bias_attr=False)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU()
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Layer):
"""ResNet model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
block (BasicBlock|BottleneckBlock): Block module used to build the model.
depth (int): Depth of the ResNet, one of 18, 34, 50, 101 or 152.
num_classes (int): Output dimension of the last fc layer. If num_classes <= 0, the last
fc layer will not be defined. Default: 1000.
with_pool (bool): Whether to apply global average pooling before the last fc layer. Default: True.
strides (tuple, optional): Strides of the stem convolution and the four stages; added in this
adaptation to control the output stride. Default: (1, 1, 2, 2, 2).
norm_layer (paddle.nn.Layer, optional): Normalization layer to use. If None, nn.BatchNorm2D is used.
Examples:
.. code-block:: python
from paddle.vision.models import ResNet
from paddle.vision.models.resnet import BottleneckBlock, BasicBlock
resnet50 = ResNet(BottleneckBlock, 50)
resnet18 = ResNet(BasicBlock, 18)
"""
def __init__(self, block, depth, num_classes=1000, with_pool=True, strides=(1,1,2,2,2), norm_layer=None):
super(ResNet, self).__init__()
layer_cfg = {
18: [2, 2, 2, 2],
34: [3, 4, 6, 3],
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
152: [3, 8, 36, 3]
}
layers = layer_cfg[depth]
self.num_classes = num_classes
self.with_pool = with_pool
self._norm_layer = nn.BatchNorm2D if norm_layer is None else norm_layer
self.inplanes = 64
self.dilation = 1
self.conv1 = nn.Conv2D(
3,
self.inplanes,
kernel_size=7,
stride=strides[0],
padding=3,
bias_attr=False)
self.bn1 = self._norm_layer(self.inplanes)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[1])
self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[2])
self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[3])
self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[4])
if with_pool:
self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
if num_classes > 0:
self.fc = nn.Linear(512 * block.expansion, num_classes)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2D(
self.inplanes,
planes * block.expansion,
1,
stride=stride,
bias_attr=False),
norm_layer(planes * block.expansion), )
layers = []
layers.append(
block(self.inplanes, planes, stride, downsample, 1, 64,
previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
if self.with_pool:
x = self.avgpool(x)
if self.num_classes > 0:
x = paddle.flatten(x, 1)
x = self.fc(x)
return x
def _resnet(arch, Block, depth, pretrained, **kwargs):
model = ResNet(Block, depth, **kwargs)
if pretrained:
assert arch in model_urls, "{} does not have a pretrained model; please set pretrained=False".format(arch)
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])
param = paddle.load(weight_path)
model.set_dict(param)
return model
def resnet18(pretrained=False, **kwargs):
"""ResNet 18-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet18
# build model
model = resnet18()
# build model and load imagenet pretrained weight
# model = resnet18(pretrained=True)
"""
return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs)
def resnet34(pretrained=False, **kwargs):
"""ResNet 34-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet34
# build model
model = resnet34()
# build model and load imagenet pretrained weight
# model = resnet34(pretrained=True)
"""
return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs)
def resnet50(pretrained=False, **kwargs):
"""ResNet 50-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet50
# build model
model = resnet50()
# build model and load imagenet pretrained weight
# model = resnet50(pretrained=True)
"""
return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs)
def resnet101(pretrained=False, **kwargs):
"""ResNet 101-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet101
# build model
model = resnet101()
# build model and load imagenet pretrained weight
# model = resnet101(pretrained=True)
"""
return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs)
def resnet152(pretrained=False, **kwargs):
"""ResNet 152-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet152
# build model
model = resnet152()
# build model and load imagenet pretrained weight
# model = resnet152(pretrained=True)
"""
return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs)
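Compared with the upstream paddle.vision ResNet, this adapted copy adds the strides and norm_layer arguments, which the change detection models below use to keep feature maps at a higher resolution. A minimal usage sketch (my own illustration, not part of the diff; shapes assume a 256x256 input):

# Hypothetical usage sketch, not part of this PR.
import paddle
from paddlers.models.cd.models.backbones import resnet

# Use stride 1 in stages 1, 3 and 4, as BIT's Backbone does with
# strides=(2, 1, 2, 1, 1); num_classes=0 and with_pool=False skip the head.
net = resnet.resnet18(
    pretrained=False, num_classes=0, with_pool=False, strides=(2, 1, 2, 1, 1))
x = paddle.randn([1, 3, 256, 256])
y = net(x)  # stride-2 stem, stride-2 maxpool, one stride-2 stage -> 1/8 scale
print(y.shape)  # expected: [1, 512, 32, 32]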

paddlers/models/cd/models/bit.py
@@ -0,0 +1,395 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal
from .backbones import resnet
from .layers import Conv3x3, Conv1x1, get_norm_layer, Identity
from .param_init import KaimingInitMixin
class BIT(nn.Layer):
"""
The BIT implementation based on PaddlePaddle.
The original article refers to
H. Chen, et al., "Remote Sensing Image Change Detection With Transformers"
(https://arxiv.org/abs/2103.00208).
This implementation adopts pretrained encoders, as opposed to the original work where weights are randomly initialized.
Args:
in_channels (int): The number of bands of the input images.
num_classes (int): The number of target classes.
backbone (str, optional): The ResNet architecture that is used as the backbone. Currently, only 'resnet18' and
'resnet34' are supported. Default: 'resnet18'.
n_stages (int, optional): The number of ResNet stages used in the backbone, which should be a value in {3,4,5}.
Default: 4.
use_tokenizer (bool, optional): Use a tokenizer or not. Default: True.
token_len (int, optional): The length of input tokens. Default: 4.
pool_mode (str, optional): The pooling strategy to obtain input tokens when `use_tokenizer` is set to False. 'max'
for global max pooling and 'avg' for global average pooling. Default: 'max'.
pool_size (int, optional): The height and width of the pooled feature maps when `use_tokenizer` is set to False.
Default: 2.
enc_with_pos (bool, optional): Whether to add learned positional embeddings to the input feature sequence of the
encoder. Default: True.
enc_depth (int, optional): The number of attention blocks used in the encoder. Default: 1.
enc_head_dim (int, optional): The embedding dimension of each encoder head. Default: 64.
dec_depth (int, optional): The number of attention blocks used in the decoder. Default: 8.
dec_head_dim (int, optional): The embedding dimension of each decoder head. Default: 8.
Raises:
ValueError: When an unsupported backbone type is specified, or the number of backbone stages is not 3, 4, or 5.
"""
def __init__(
self, in_channels, num_classes,
backbone='resnet18', n_stages=4,
use_tokenizer=True, token_len=4,
pool_mode='max', pool_size=2,
enc_with_pos=True,
enc_depth=1, enc_head_dim=64,
dec_depth=8, dec_head_dim=8,
**backbone_kwargs
):
super().__init__()
# TODO: reduce hard-coded parameters
DIM = 32
MLP_DIM = 2*DIM
EBD_DIM = DIM
self.backbone = Backbone(in_channels, EBD_DIM, arch=backbone, n_stages=n_stages, **backbone_kwargs)
self.use_tokenizer = use_tokenizer
if not use_tokenizer:
# If a tokenizer is not to be used, then downsample the feature maps.
self.pool_size = pool_size
self.pool_mode = pool_mode
self.token_len = pool_size * pool_size
else:
self.conv_att = Conv1x1(32, token_len, bias=False)
self.token_len = token_len
self.enc_with_pos = enc_with_pos
if enc_with_pos:
self.enc_pos_embedding = self.create_parameter(
shape=(1,self.token_len*2,EBD_DIM),
default_initializer=Normal()
)
self.enc_depth = enc_depth
self.dec_depth = dec_depth
self.enc_head_dim = enc_head_dim
self.dec_head_dim = dec_head_dim
self.encoder = TransformerEncoder(
dim=DIM,
depth=enc_depth,
n_heads=8,
head_dim=enc_head_dim,
mlp_dim=MLP_DIM,
dropout_rate=0.
)
self.decoder = TransformerDecoder(
dim=DIM,
depth=dec_depth,
n_heads=8,
head_dim=dec_head_dim,
mlp_dim=MLP_DIM,
dropout_rate=0.,
apply_softmax=True
)
self.upsample = nn.Upsample(scale_factor=4, mode='bilinear')
self.conv_out = nn.Sequential(
Conv3x3(EBD_DIM, EBD_DIM, norm=True, act=True),
Conv3x3(EBD_DIM, num_classes)
)
def _get_semantic_tokens(self, x):
b, c = paddle.shape(x)[:2]
att_map = self.conv_att(x)
att_map = att_map.reshape((b,self.token_len,1,-1))
att_map = F.softmax(att_map, axis=-1)
x = x.reshape((b,1,c,-1))
tokens = (x*att_map).sum(-1)
return tokens
def _get_reshaped_tokens(self, x):
if self.pool_mode == 'max':
x = F.adaptive_max_pool2d(x, (self.pool_size, self.pool_size))
elif self.pool_mode == 'avg':
x = F.adaptive_avg_pool2d(x, (self.pool_size, self.pool_size))
else:
x = x
tokens = x.transpose((0,2,3,1)).flatten(1,2)
return tokens
def encode(self, x):
if self.enc_with_pos:
x += self.enc_pos_embedding
x = self.encoder(x)
return x
def decode(self, x, m):
b, c, h, w = paddle.shape(x)
x = x.transpose((0,2,3,1)).flatten(1,2)
x = self.decoder(x, m)
x = x.transpose((0,2,1)).reshape((b,c,h,w))
return x
def forward(self, t1, t2):
# Extract features via shared backbone.
x1 = self.backbone(t1)
x2 = self.backbone(t2)
# Tokenization
if self.use_tokenizer:
token1 = self._get_semantic_tokens(x1)
token2 = self._get_semantic_tokens(x2)
else:
token1 = self._get_reshaped_tokens(x1)
token2 = self._get_reshaped_tokens(x2)
# Transformer encoder forward
token = paddle.concat([token1, token2], axis=1)
token = self.encode(token)
token1, token2 = paddle.chunk(token, 2, axis=1)
# Transformer decoder forward
y1 = self.decode(x1, token1)
y2 = self.decode(x2, token2)
# Feature differencing
y = paddle.abs(y1 - y2)
y = self.upsample(y)
# Classifier forward
pred = self.conv_out(y)
return pred,
def init_weight(self):
# Use the default initialization method.
pass
class Residual(nn.Layer):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class Residual2(nn.Layer):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x1, x2, **kwargs):
return self.fn(x1, x2, **kwargs) + x1
class PreNorm(nn.Layer):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class PreNorm2(nn.Layer):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x1, x2, **kwargs):
return self.fn(self.norm(x1), self.norm(x2), **kwargs)
class FeedForward(nn.Sequential):
def __init__(self, dim, hidden_dim, dropout_rate=0.):
super().__init__(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout_rate),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout_rate)
)
class CrossAttention(nn.Layer):
def __init__(self, dim, n_heads=8, head_dim=64, dropout_rate=0., apply_softmax=True):
super().__init__()
inner_dim = head_dim * n_heads
self.n_heads = n_heads
self.scale = dim ** -0.5
self.apply_softmax = apply_softmax
self.fc_q = nn.Linear(dim, inner_dim, bias_attr=False)
self.fc_k = nn.Linear(dim, inner_dim, bias_attr=False)
self.fc_v = nn.Linear(dim, inner_dim, bias_attr=False)
self.fc_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout_rate)
)
def forward(self, x, ref):
b, n = paddle.shape(x)[:2]
h = self.n_heads
q = self.fc_q(x)
k = self.fc_k(ref)
v = self.fc_v(ref)
q = q.reshape((b,n,h,-1)).transpose((0,2,1,3))
k = k.reshape((b,paddle.shape(ref)[1],h,-1)).transpose((0,2,1,3))
v = v.reshape((b,paddle.shape(ref)[1],h,-1)).transpose((0,2,1,3))
mult = paddle.matmul(q, k, transpose_y=True) * self.scale
if self.apply_softmax:
mult = F.softmax(mult, axis=-1)
out = paddle.matmul(mult, v)
out = out.transpose((0,2,1,3)).flatten(2)
return self.fc_out(out)
class SelfAttention(CrossAttention):
def forward(self, x):
return super().forward(x, x)
class TransformerEncoder(nn.Layer):
def __init__(self, dim, depth, n_heads, head_dim, mlp_dim, dropout_rate):
super().__init__()
self.layers = nn.LayerList([])
for _ in range(depth):
self.layers.append(nn.LayerList([
Residual(PreNorm(dim, SelfAttention(dim, n_heads, head_dim, dropout_rate))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate)))
]))
def forward(self, x):
for att, ff in self.layers:
x = att(x)
x = ff(x)
return x
class TransformerDecoder(nn.Layer):
def __init__(self, dim, depth, n_heads, head_dim, mlp_dim, dropout_rate, apply_softmax=True):
super().__init__()
self.layers = nn.LayerList([])
for _ in range(depth):
self.layers.append(nn.LayerList([
Residual2(PreNorm2(dim, CrossAttention(dim, n_heads, head_dim, dropout_rate, apply_softmax))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate)))
]))
def forward(self, x, m):
for att, ff in self.layers:
x = att(x, m)
x = ff(x)
return x
class Backbone(nn.Layer, KaimingInitMixin):
def __init__(
self,
in_ch, out_ch=32,
arch='resnet18',
pretrained=True,
n_stages=5
):
super().__init__()
expand = 1
strides = (2,1,2,1,1)
if arch == 'resnet18':
self.resnet = resnet.resnet18(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer())
elif arch == 'resnet34':
self.resnet = resnet.resnet34(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer())
else:
raise ValueError("Unsupported backbone architecture: {}".format(arch))
self.n_stages = n_stages
if self.n_stages == 5:
itm_ch = 512 * expand
elif self.n_stages == 4:
itm_ch = 256 * expand
elif self.n_stages == 3:
itm_ch = 128 * expand
else:
raise ValueError("n_stages must be a value in {3, 4, 5}.")
self.upsample = nn.Upsample(scale_factor=2)
self.conv_out = Conv3x3(itm_ch, out_ch)
self._trim_resnet()
if in_ch != 3:
self.resnet.conv1 = nn.Conv2D(
in_ch,
64,
kernel_size=7,
stride=2,
padding=3,
bias_attr=False
)
if not pretrained:
self.init_weight()
def forward(self, x):
y = self.resnet.conv1(x)
y = self.resnet.bn1(y)
y = self.resnet.relu(y)
y = self.resnet.maxpool(y)
y = self.resnet.layer1(y)
y = self.resnet.layer2(y)
y = self.resnet.layer3(y)
y = self.resnet.layer4(y)
y = self.upsample(y)
return self.conv_out(y)
def _trim_resnet(self):
if self.n_stages > 5:
raise ValueError("n_stages must not be greater than 5.")
if self.n_stages < 5:
self.resnet.layer4 = Identity()
if self.n_stages <= 3:
self.resnet.layer3 = Identity()
self.resnet.avgpool = Identity()
self.resnet.fc = Identity()
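As defined above, BIT returns its prediction as a one-element tuple and forwards extra keyword arguments to the backbone. A hedged smoke test (my own example; shapes assume 256x256 inputs):

# Hypothetical smoke test, not part of this PR.
import paddle
from paddlers.models.cd.models import BIT

# pretrained=False is forwarded to Backbone via **backbone_kwargs, so no
# ImageNet weights are downloaded for a quick shape check.
model = BIT(in_channels=3, num_classes=2, pretrained=False)
t1 = paddle.randn([1, 3, 256, 256])
t2 = paddle.randn([1, 3, 256, 256])
pred, = model(t1, t2)  # unpack the one-element tuple
print(pred.shape)  # expected: [1, 2, 256, 256]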

paddlers/models/cd/models/dsamnet.py
@@ -0,0 +1,96 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .layers import make_norm, Conv3x3, CBAM
from .stanet import Backbone, Decoder
class DSAMNet(nn.Layer):
"""
The DSAMNet implementation based on PaddlePaddle.
The original article refers to
Q. Shi, et al., "A Deeply Supervised Attention Metric-Based Network and an Open Aerial Image Dataset for Remote Sensing
Change Detection"
(https://ieeexplore.ieee.org/document/9467555).
Note that this implementation differs from the original work in two aspects:
1. We do not use multiple dilation rates in layer 4 of the ResNet backbone.
2. A classification head is used in place of the original metric learning-based head to stabilize the training process.
Args:
in_channels (int): The number of bands of the input images.
num_classes (int): The number of target classes.
ca_ratio (int, optional): The channel reduction ratio for the channel attention module. Default: 8.
sa_kernel (int, optional): The size of the convolutional kernel used in the spatial attention module. Default: 7.
"""
def __init__(self, in_channels, num_classes, ca_ratio=8, sa_kernel=7):
super().__init__()
WIDTH = 64
self.backbone = Backbone(in_ch=in_channels, arch='resnet18', strides=(1,1,2,2,1))
self.decoder = Decoder(WIDTH)
self.cbam1 = CBAM(64, ratio=ca_ratio, kernel_size=sa_kernel)
self.cbam2 = CBAM(64, ratio=ca_ratio, kernel_size=sa_kernel)
self.dsl2 = DSLayer(64, num_classes, 32, stride=2, output_padding=1)
self.dsl3 = DSLayer(128, num_classes, 32, stride=4, output_padding=3)
self.conv_out = nn.Sequential(
Conv3x3(WIDTH, WIDTH, norm=True, act=True),
Conv3x3(WIDTH, num_classes)
)
self.init_weight()
def forward(self, t1, t2):
f1 = self.backbone(t1)
f2 = self.backbone(t2)
y1 = self.decoder(f1)
y2 = self.decoder(f2)
y1 = self.cbam1(y1)
y2 = self.cbam2(y2)
out = paddle.abs(y1-y2)
out = F.interpolate(out, size=paddle.shape(t1)[2:], mode='bilinear', align_corners=True)
pred = self.conv_out(out)
ds2 = self.dsl2(paddle.abs(f1[0]-f2[0]))
ds3 = self.dsl3(paddle.abs(f1[1]-f2[1]))
return pred, ds2, ds3
def init_weight(self):
pass
class DSLayer(nn.Sequential):
def __init__(self, in_ch, out_ch, itm_ch, **convd_kwargs):
super().__init__(
nn.Conv2DTranspose(in_ch, itm_ch, kernel_size=3, padding=1, **convd_kwargs),
make_norm(itm_ch),
nn.ReLU(),
nn.Dropout2D(p=0.2),
nn.Conv2DTranspose(itm_ch, out_ch, kernel_size=3, padding=1)
)
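The two DSLayer heads above provide the deep supervision signals: ds2 and ds3 are auxiliary predictions computed from the absolute difference of the stage-1 and stage-2 backbone features, and the transposed convolutions bring both back to input resolution. A hedged smoke test (my own example; the ResNet backbone fetches ImageNet weights on first use):

# Hypothetical smoke test, not part of this PR.
import paddle
from paddlers.models.cd.models import DSAMNet

model = DSAMNet(in_channels=3, num_classes=2)
t1 = paddle.randn([1, 3, 256, 256])
t2 = paddle.randn([1, 3, 256, 256])
pred, ds2, ds3 = model(t1, t2)  # main prediction plus two auxiliary outputs
print(pred.shape, ds2.shape, ds3.shape)  # each expected: [1, 2, 256, 256]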

paddlers/models/cd/models/dsifn.py
@@ -0,0 +1,209 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.vision.models import vgg16
from .layers import Conv1x1, make_norm, ChannelAttention, SpatialAttention
class DSIFN(nn.Layer):
"""
The DSIFN implementation based on PaddlePaddle.
The original article refers to
C. Zhang, et al., "A deeply supervised image fusion network for change detection in high resolution bi-temporal remote
sensing images"
(https://www.sciencedirect.com/science/article/pii/S0924271620301532).
Note that, unlike the original work, this implementation supports an arbitrary number of target classes.
Args:
num_classes (int): The number of target classes.
use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is trained
on a relatively small dataset, the dropout layers help prevent overfitting. Default: False.
"""
def __init__(self, num_classes, use_dropout=False):
super().__init__()
self.encoder1 = self.encoder2 = VGG16FeaturePicker()
self.sa1 = SpatialAttention()
self.sa2 = SpatialAttention()
self.sa3 = SpatialAttention()
self.sa4 = SpatialAttention()
self.sa5 = SpatialAttention()
self.ca1 = ChannelAttention(in_ch=1024)
self.bn_ca1 = make_norm(1024)
self.o1_conv1 = conv2d_bn(1024, 512, use_dropout)
self.o1_conv2 = conv2d_bn(512, 512, use_dropout)
self.bn_sa1 = make_norm(512)
self.o1_conv3 = Conv1x1(512, num_classes)
self.trans_conv1 = nn.Conv2DTranspose(512, 512, kernel_size=2, stride=2)
self.ca2 = ChannelAttention(in_ch=1536)
self.bn_ca2 = make_norm(1536)
self.o2_conv1 = conv2d_bn(1536, 512, use_dropout)
self.o2_conv2 = conv2d_bn(512, 256, use_dropout)
self.o2_conv3 = conv2d_bn(256, 256, use_dropout)
self.bn_sa2 = make_norm(256)
self.o2_conv4 = Conv1x1(256, num_classes)
self.trans_conv2 = nn.Conv2DTranspose(256, 256, kernel_size=2, stride=2)
self.ca3 = ChannelAttention(in_ch=768)
self.o3_conv1 = conv2d_bn(768, 256, use_dropout)
self.o3_conv2 = conv2d_bn(256, 128, use_dropout)
self.o3_conv3 = conv2d_bn(128, 128, use_dropout)
self.bn_sa3 = make_norm(128)
self.o3_conv4 = Conv1x1(128, num_classes)
self.trans_conv3 = nn.Conv2DTranspose(128, 128, kernel_size=2, stride=2)
self.ca4 = ChannelAttention(in_ch=384)
self.o4_conv1 = conv2d_bn(384, 128, use_dropout)
self.o4_conv2 = conv2d_bn(128, 64, use_dropout)
self.o4_conv3 = conv2d_bn(64, 64, use_dropout)
self.bn_sa4 = make_norm(64)
self.o4_conv4 = Conv1x1(64, num_classes)
self.trans_conv4 = nn.Conv2DTranspose(64, 64, kernel_size=2, stride=2)
self.ca5 = ChannelAttention(in_ch=192)
self.o5_conv1 = conv2d_bn(192, 64, use_dropout)
self.o5_conv2 = conv2d_bn(64, 32, use_dropout)
self.o5_conv3 = conv2d_bn(32, 16, use_dropout)
self.bn_sa5 = make_norm(16)
self.o5_conv4 = Conv1x1(16, num_classes)
self.init_weight()
def forward(self, t1, t2):
# Extract bi-temporal features.
with paddle.no_grad():
self.encoder1.eval()
self.encoder2.eval()
t1_feats = self.encoder1(t1)
t2_feats = self.encoder2(t2)
t1_f_l3, t1_f_l8, t1_f_l15, t1_f_l22, t1_f_l29 = t1_feats
t2_f_l3, t2_f_l8, t2_f_l15, t2_f_l22, t2_f_l29 = t2_feats
# Multi-level decoding
x = paddle.concat([t1_f_l29, t2_f_l29], axis=1)
x = self.o1_conv1(x)
x = self.o1_conv2(x)
x = self.sa1(x) * x
x = self.bn_sa1(x)
out1 = F.interpolate(
self.o1_conv3(x),
size=paddle.shape(t1)[2:],
mode='bilinear',
align_corners=True
)
x = self.trans_conv1(x)
x = paddle.concat([x, t1_f_l22, t2_f_l22], axis=1)
x = self.ca2(x) * x
x = self.o2_conv1(x)
x = self.o2_conv2(x)
x = self.o2_conv3(x)
x = self.sa2(x) * x
x = self.bn_sa2(x)
out2 = F.interpolate(
self.o2_conv4(x),
size=paddle.shape(t1)[2:],
mode='bilinear',
align_corners=True
)
x = self.trans_conv2(x)
x = paddle.concat([x, t1_f_l15, t2_f_l15], axis=1)
x = self.ca3(x) * x
x = self.o3_conv1(x)
x = self.o3_conv2(x)
x = self.o3_conv3(x)
x = self.sa3(x) * x
x = self.bn_sa3(x)
out3 = F.interpolate(
self.o3_conv4(x),
size=paddle.shape(t1)[2:],
mode='bilinear',
align_corners=True
)
x = self.trans_conv3(x)
x = paddle.concat([x, t1_f_l8, t2_f_l8], axis=1)
x = self.ca4(x) * x
x = self.o4_conv1(x)
x = self.o4_conv2(x)
x = self.o4_conv3(x)
x = self.sa4(x) * x
x = self.bn_sa4(x)
out4 = F.interpolate(
self.o4_conv4(x),
size=paddle.shape(t1)[2:],
mode='bilinear',
align_corners=True
)
x = self.trans_conv4(x)
x = paddle.concat([x, t1_f_l3, t2_f_l3], axis=1)
x = self.ca5(x) * x
x = self.o5_conv1(x)
x = self.o5_conv2(x)
x = self.o5_conv3(x)
x = self.sa5(x) * x
x = self.bn_sa5(x)
out5 = self.o5_conv4(x)
return out5, out4, out3, out2, out1
def init_weight(self):
# Do nothing
pass
class VGG16FeaturePicker(nn.Layer):
def __init__(self, indices=(3,8,15,22,29)):
super().__init__()
features = list(vgg16(pretrained=True).features)[:30]
self.features = nn.LayerList(features)
self.features.eval()
self.indices = set(indices)
def forward(self, x):
picked_feats = []
for idx, model in enumerate(self.features):
x = model(x)
if idx in self.indices:
picked_feats.append(x)
return picked_feats
def conv2d_bn(in_ch, out_ch, with_dropout=True):
lst = [
nn.Conv2D(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
nn.PReLU(),
make_norm(out_ch),
]
if with_dropout:
lst.append(nn.Dropout(p=0.6))
return nn.Sequential(*lst)
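Because of the chained assignment in __init__, encoder1 and encoder2 are one and the same VGG16FeaturePicker, so the two temporal images share a single frozen VGG16 encoder. A hedged smoke test (my own example; the VGG16 weights are downloaded on first use):

# Hypothetical smoke test, not part of this PR.
import paddle
from paddlers.models.cd.models import DSIFN

model = DSIFN(num_classes=2)
t1 = paddle.randn([1, 3, 256, 256])
t2 = paddle.randn([1, 3, 256, 256])
outs = model(t1, t2)  # (out5, out4, out3, out2, out1); out5 is the final fused prediction
print([tuple(o.shape) for o in outs])  # each expected: (1, 2, 256, 256)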

paddlers/models/cd/models/layers/__init__.py
@@ -0,0 +1,16 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .blocks import *
from .attention import ChannelAttention, SpatialAttention, CBAM

paddlers/models/cd/models/layers/attention.py
@@ -0,0 +1,96 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .blocks import Conv1x1, BasicConv
class ChannelAttention(nn.Layer):
"""
The channel attention module implementation based on PaddlePaddle.
The original article refers to
Sanghyun Woo, et al., "CBAM: Convolutional Block Attention Module"
(https://arxiv.org/abs/1807.06521).
Args:
in_ch (int): The number of channels of the input features.
ratio (int, optional): The channel reduction ratio. Default: 8.
"""
def __init__(self, in_ch, ratio=8):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.max_pool = nn.AdaptiveMaxPool2D(1)
self.fc1 = Conv1x1(in_ch, in_ch//ratio, bias=False, act=True)
self.fc2 = Conv1x1(in_ch//ratio, in_ch, bias=False)
def forward(self, x):
avg_out = self.fc2(self.fc1(self.avg_pool(x)))
max_out = self.fc2(self.fc1(self.max_pool(x)))
out = avg_out + max_out
return F.sigmoid(out)
class SpatialAttention(nn.Layer):
"""
The spatial attention module implementation based on PaddlePaddle.
The original article refers to
Sanghyun Woo, et al., "CBAM: Convolutional Block Attention Module"
(https://arxiv.org/abs/1807.06521).
Args:
kernel_size (int, optional): The size of the convolutional kernel. Default: 7.
"""
def __init__(self, kernel_size=7):
super().__init__()
self.conv = BasicConv(2, 1, kernel_size, bias=False)
def forward(self, x):
avg_out = paddle.mean(x, axis=1, keepdim=True)
max_out = paddle.max(x, axis=1, keepdim=True)
x = paddle.concat([avg_out, max_out], axis=1)
x = self.conv(x)
return F.sigmoid(x)
class CBAM(nn.Layer):
"""
The CBAM implementation based on PaddlePaddle.
The original article refers to
Sanghyun Woo, et al., "CBAM: Convolutional Block Attention Module"
(https://arxiv.org/abs/1807.06521).
Args:
in_ch (int): The number of channels of the input features.
ratio (int, optional): The channel reduction ratio for the channel attention module. Default: 8.
kernel_size (int, optional): The size of the convolutional kernel used in the spatial attention module. Default: 7.
"""
def __init__(self, in_ch, ratio=8, kernel_size=7):
super().__init__()
self.ca = ChannelAttention(in_ch, ratio=ratio)
self.sa = SpatialAttention(kernel_size=kernel_size)
def forward(self, x):
y = self.ca(x) * x
y = self.sa(y) * y
return y
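To summarize the three modules: ChannelAttention produces a per-channel gate, SpatialAttention a per-pixel gate, and CBAM applies the two in sequence. A small shape check (my own example, not part of the diff):

# Hypothetical shape check, not part of this PR.
import paddle
from paddlers.models.cd.models.layers import ChannelAttention, SpatialAttention, CBAM

x = paddle.randn([2, 64, 32, 32])
print(ChannelAttention(in_ch=64, ratio=8)(x).shape)  # channel gate: [2, 64, 1, 1]
print(SpatialAttention(kernel_size=7)(x).shape)      # spatial gate: [2, 1, 32, 32]
print(CBAM(in_ch=64)(x).shape)                       # gated features: [2, 64, 32, 32]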

paddlers/models/cd/models/layers/blocks.py
@@ -0,0 +1,142 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.nn as nn
__all__ = [
'BasicConv', 'Conv1x1', 'Conv3x3', 'Conv7x7',
'MaxPool2x2', 'MaxUnPool2x2',
'ConvTransposed3x3',
'Identity',
'get_norm_layer', 'get_act_layer',
'make_norm', 'make_act'
]
def get_norm_layer():
# TODO: select appropriate norm layer.
return nn.BatchNorm2D
def get_act_layer():
# TODO: select appropriate activation layer.
return nn.ReLU
def make_norm(*args, **kwargs):
norm_layer = get_norm_layer()
return norm_layer(*args, **kwargs)
def make_act(*args, **kwargs):
act_layer = get_act_layer()
return act_layer(*args, **kwargs)
class BasicConv(nn.Layer):
def __init__(
self, in_ch, out_ch,
kernel_size, pad_mode='constant',
bias='auto', norm=False, act=False,
**kwargs
):
super().__init__()
seq = []
if kernel_size >= 2:
seq.append(nn.Pad2D(kernel_size//2, mode=pad_mode))
seq.append(
nn.Conv2D(
in_ch, out_ch, kernel_size,
stride=1, padding=0,
bias_attr=(False if norm else None) if bias=='auto' else bias,
**kwargs
)
)
if norm:
if norm is True:
norm = make_norm(out_ch)
seq.append(norm)
if act:
if act is True:
act = make_act()
seq.append(act)
self.seq = nn.Sequential(*seq)
def forward(self, x):
return self.seq(x)
class Conv1x1(BasicConv):
def __init__(self, in_ch, out_ch, pad_mode='constant', bias='auto', norm=False, act=False, **kwargs):
super().__init__(in_ch, out_ch, 1, pad_mode=pad_mode, bias=bias, norm=norm, act=act, **kwargs)
class Conv3x3(BasicConv):
def __init__(self, in_ch, out_ch, pad_mode='constant', bias='auto', norm=False, act=False, **kwargs):
super().__init__(in_ch, out_ch, 3, pad_mode=pad_mode, bias=bias, norm=norm, act=act, **kwargs)
class Conv7x7(BasicConv):
def __init__(self, in_ch, out_ch, pad_mode='constant', bias='auto', norm=False, act=False, **kwargs):
super().__init__(in_ch, out_ch, 7, pad_mode=pad_mode, bias=bias, norm=norm, act=act, **kwargs)
class MaxPool2x2(nn.MaxPool2D):
def __init__(self, **kwargs):
super().__init__(kernel_size=2, stride=(2,2), padding=(0,0), **kwargs)
class MaxUnPool2x2(nn.MaxUnPool2D):
def __init__(self, **kwargs):
super().__init__(kernel_size=2, stride=(2,2), padding=(0,0), **kwargs)
class ConvTransposed3x3(nn.Layer):
def __init__(
self, in_ch, out_ch,
bias='auto', norm=False, act=False,
**kwargs
):
super().__init__()
seq = []
seq.append(
nn.Conv2DTranspose(
in_ch, out_ch, 3,
stride=2, padding=1,
bias_attr=(False if norm else None) if bias=='auto' else bias,
**kwargs
)
)
if norm:
if norm is True:
norm = make_norm(out_ch)
seq.append(norm)
if act:
if act is True:
act = make_act()
seq.append(act)
self.seq = nn.Sequential(*seq)
def forward(self, x):
return self.seq(x)
class Identity(nn.Layer):
"""A placeholder identity operator that accepts exactly one argument."""
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, x):
return x
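One detail of BasicConv worth noting is the bias='auto' convention: when a norm layer follows, the convolution bias is dropped (bias_attr=False) since the subsequent BatchNorm would cancel it anyway; without a norm layer, Paddle's default bias is kept. A brief sketch under that reading (my own example):

# Hypothetical usage sketch, not part of this PR.
import paddle
from paddlers.models.cd.models.layers import Conv3x3

# norm=True drops the conv bias ('auto' -> bias_attr=False) and appends a
# BatchNorm2D; act=True appends a ReLU.
conv_bn_relu = Conv3x3(16, 32, norm=True, act=True)
x = paddle.randn([1, 16, 64, 64])
print(conv_bn_relu(x).shape)  # [1, 32, 64, 64]; the Pad2D keeps spatial size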

paddlers/models/cd/models/param_init.py
@@ -0,0 +1,86 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.nn as nn
import paddle.nn.functional as F
def normal_init(param, *args, **kwargs):
"""
Initialize parameters with a normal distribution.
Args:
param (Tensor): The tensor that needs to be initialized.
Returns:
The initialized parameters.
"""
return nn.initializer.Normal(*args, **kwargs)(param)
def kaiming_normal_init(param, *args, **kwargs):
"""
Initialize parameters with the Kaiming normal distribution.
For more information about the Kaiming initialization method, please refer to
https://arxiv.org/abs/1502.01852
Args:
param (Tensor): The tensor that needs to be initialized.
Returns:
The initialized parameters.
"""
return nn.initializer.KaimingNormal(*args, **kwargs)(param)
def constant_init(param, *args, **kwargs):
"""
Initialize parameters with constants.
Args:
param (Tensor): The tensor that needs to be initialized.
Returns:
The initialized parameters.
"""
return nn.initializer.Constant(*args, **kwargs)(param)
class KaimingInitMixin:
"""
A mix-in that provides the Kaiming initialization functionality.
Examples:
from paddlers.models.cd.models.param_init import KaimingInitMixin
class CustomNet(nn.Layer, KaimingInitMixin):
def __init__(self, num_channels, num_classes):
super().__init__()
self.conv = nn.Conv2D(num_channels, num_classes, 3, 1, 0, bias_attr=False)
self.bn = nn.BatchNorm2D(num_classes)
self.init_weight()
"""
def init_weight(self):
for layer in self.sublayers():
if isinstance(layer, nn.Conv2D):
kaiming_normal_init(layer.weight)
elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
constant_init(layer.weight, value=1)
constant_init(layer.bias, value=0)

paddlers/models/cd/models/snunet.py
@@ -0,0 +1,155 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .layers import Conv1x1, MaxPool2x2, make_norm, ChannelAttention
from .param_init import KaimingInitMixin
class SNUNet(nn.Layer, KaimingInitMixin):
"""
The SNUNet implementation based on PaddlePaddle.
The original article refers to
S. Fang, et al., "SNUNet-CD: A Densely Connected Siamese Network for Change Detection of VHR Images"
(https://ieeexplore.ieee.org/document/9355573).
Note that bilinear interpolation is adopted as the upsampling method, which is different from the paper.
Args:
in_channels (int): The number of bands of the input images.
num_classes (int): The number of target classes.
width (int, optional): The output channels of the first convolutional layer. Default: 32.
"""
def __init__(self, in_channels, num_classes, width=32):
super().__init__()
filters = (width, width*2, width*4, width*8, width*16)
self.conv0_0 = ConvBlockNested(in_channels, filters[0], filters[0])
self.conv1_0 = ConvBlockNested(filters[0], filters[1], filters[1])
self.conv2_0 = ConvBlockNested(filters[1], filters[2], filters[2])
self.conv3_0 = ConvBlockNested(filters[2], filters[3], filters[3])
self.conv4_0 = ConvBlockNested(filters[3], filters[4], filters[4])
self.down1 = MaxPool2x2()
self.down2 = MaxPool2x2()
self.down3 = MaxPool2x2()
self.down4 = MaxPool2x2()
self.up1_0 = Up(filters[1])
self.up2_0 = Up(filters[2])
self.up3_0 = Up(filters[3])
self.up4_0 = Up(filters[4])
self.conv0_1 = ConvBlockNested(filters[0]*2+filters[1], filters[0], filters[0])
self.conv1_1 = ConvBlockNested(filters[1]*2+filters[2], filters[1], filters[1])
self.conv2_1 = ConvBlockNested(filters[2]*2+filters[3], filters[2], filters[2])
self.conv3_1 = ConvBlockNested(filters[3]*2+filters[4], filters[3], filters[3])
self.up1_1 = Up(filters[1])
self.up2_1 = Up(filters[2])
self.up3_1 = Up(filters[3])
self.conv0_2 = ConvBlockNested(filters[0]*3+filters[1], filters[0], filters[0])
self.conv1_2 = ConvBlockNested(filters[1]*3+filters[2], filters[1], filters[1])
self.conv2_2 = ConvBlockNested(filters[2]*3+filters[3], filters[2], filters[2])
self.up1_2 = Up(filters[1])
self.up2_2 = Up(filters[2])
self.conv0_3 = ConvBlockNested(filters[0]*4+filters[1], filters[0], filters[0])
self.conv1_3 = ConvBlockNested(filters[1]*4+filters[2], filters[1], filters[1])
self.up1_3 = Up(filters[1])
self.conv0_4 = ConvBlockNested(filters[0]*5+filters[1], filters[0], filters[0])
self.ca_intra = ChannelAttention(filters[0], ratio=4)
self.ca_inter = ChannelAttention(filters[0]*4, ratio=16)
self.conv_out = Conv1x1(filters[0]*4, num_classes)
self.init_weight()
def forward(self, t1, t2):
x0_0_t1 = self.conv0_0(t1)
x1_0_t1 = self.conv1_0(self.down1(x0_0_t1))
x2_0_t1 = self.conv2_0(self.down2(x1_0_t1))
x3_0_t1 = self.conv3_0(self.down3(x2_0_t1))
x0_0_t2 = self.conv0_0(t2)
x1_0_t2 = self.conv1_0(self.down1(x0_0_t2))
x2_0_t2 = self.conv2_0(self.down2(x1_0_t2))
x3_0_t2 = self.conv3_0(self.down3(x2_0_t2))
x4_0_t2 = self.conv4_0(self.down4(x3_0_t2))
x0_1 = self.conv0_1(paddle.concat([x0_0_t1, x0_0_t2, self.up1_0(x1_0_t2)], 1))
x1_1 = self.conv1_1(paddle.concat([x1_0_t1, x1_0_t2, self.up2_0(x2_0_t2)], 1))
x0_2 = self.conv0_2(paddle.concat([x0_0_t1, x0_0_t2, x0_1, self.up1_1(x1_1)], 1))
x2_1 = self.conv2_1(paddle.concat([x2_0_t1, x2_0_t2, self.up3_0(x3_0_t2)], 1))
x1_2 = self.conv1_2(paddle.concat([x1_0_t1, x1_0_t2, x1_1, self.up2_1(x2_1)], 1))
x0_3 = self.conv0_3(paddle.concat([x0_0_t1, x0_0_t2, x0_1, x0_2, self.up1_2(x1_2)], 1))
x3_1 = self.conv3_1(paddle.concat([x3_0_t1, x3_0_t2, self.up4_0(x4_0_t2)], 1))
x2_2 = self.conv2_2(paddle.concat([x2_0_t1, x2_0_t2, x2_1, self.up3_1(x3_1)], 1))
x1_3 = self.conv1_3(paddle.concat([x1_0_t1, x1_0_t2, x1_1, x1_2, self.up2_2(x2_2)], 1))
x0_4 = self.conv0_4(paddle.concat([x0_0_t1, x0_0_t2, x0_1, x0_2, x0_3, self.up1_3(x1_3)], 1))
out = paddle.concat([x0_1, x0_2, x0_3, x0_4], 1)
intra = paddle.sum(paddle.stack([x0_1, x0_2, x0_3, x0_4]), axis=0)
m_intra = self.ca_intra(intra)
out = self.ca_inter(out) * (out + paddle.tile(m_intra, (1,4,1,1)))
pred = self.conv_out(out)
return pred,
class ConvBlockNested(nn.Layer):
def __init__(self, in_ch, out_ch, mid_ch):
super().__init__()
self.act = nn.ReLU()
self.conv1 = nn.Conv2D(in_ch, mid_ch, kernel_size=3, padding=1)
self.bn1 = make_norm(mid_ch)
self.conv2 = nn.Conv2D(mid_ch, out_ch, kernel_size=3, padding=1)
self.bn2 = make_norm(out_ch)
def forward(self, x):
x = self.conv1(x)
identity = x
x = self.bn1(x)
x = self.act(x)
x = self.conv2(x)
x = self.bn2(x)
output = self.act(x + identity)
return output
class Up(nn.Layer):
def __init__(self, in_ch, use_conv=False):
super().__init__()
if use_conv:
self.up = nn.Conv2DTranspose(in_ch, in_ch, 2, stride=2)
else:
self.up = nn.Upsample(scale_factor=2,
mode='bilinear',
align_corners=True)
def forward(self, x):
x = self.up(x)
return x
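Since the densely connected decoder brings every x0_* node back to the resolution of conv0_0, the final prediction keeps the input size, and width scales every filter count in the network. A hedged smoke test (my own example):

# Hypothetical smoke test, not part of this PR.
import paddle
from paddlers.models.cd.models import SNUNet

model = SNUNet(in_channels=3, num_classes=2, width=32)
t1 = paddle.randn([1, 3, 256, 256])
t2 = paddle.randn([1, 3, 256, 256])
pred, = model(t1, t2)  # unpack the one-element tuple
print(pred.shape)  # expected: [1, 2, 256, 256]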

paddlers/models/cd/models/stanet.py
@@ -0,0 +1,298 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .backbones import resnet
from .layers import Conv1x1, Conv3x3, get_norm_layer, Identity
from .param_init import KaimingInitMixin
class STANet(nn.Layer):
"""
The STANet implementation based on PaddlePaddle.
The original article refers to
H. Chen and Z. Shi, "A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection"
(https://www.mdpi.com/2072-4292/12/10/1662).
Note that this implementation differs from the original work in two aspects:
1. We do not use multiple dilation rates in layer 4 of the ResNet backbone.
2. A classification head is used in place of the original metric learning-based head to stabilize the training process.
Args:
in_channels (int): The number of bands of the input images.
num_classes (int): The number of target classes.
att_type (str, optional): The attention module used in the model. Options are 'PAM' and 'BAM'. Default: 'BAM'.
ds_factor (int, optional): The downsampling factor of the attention modules. When `ds_factor` is set to values
greater than 1, the input features will first be processed by an average pooling layer with the kernel size of
`ds_factor`, before being used to calculate the attention scores. Default: 1.
Raises:
ValueError: When `att_type` has an illegal value (unsupported attention type).
"""
def __init__(
self,
in_channels,
num_classes,
att_type='BAM',
ds_factor=1
):
super().__init__()
WIDTH = 64
self.extract = build_feat_extractor(in_ch=in_channels, width=WIDTH)
self.attend = build_sta_module(in_ch=WIDTH, att_type=att_type, ds=ds_factor)
self.conv_out = nn.Sequential(
Conv3x3(WIDTH, WIDTH, norm=True, act=True),
Conv3x3(WIDTH, num_classes)
)
self.init_weight()
def forward(self, t1, t2):
f1 = self.extract(t1)
f2 = self.extract(t2)
f1, f2 = self.attend(f1, f2)
y = paddle.abs(f1 - f2)
y = F.interpolate(y, size=paddle.shape(t1)[2:], mode='bilinear', align_corners=True)
pred = self.conv_out(y)
return pred,
def init_weight(self):
# Do nothing here as the encoder and decoder weights have already been initialized.
# Note however that currently self.attend and self.conv_out use the default initialization method.
pass
def build_feat_extractor(in_ch, width):
return nn.Sequential(
Backbone(in_ch, 'resnet18'),
Decoder(width)
)
def build_sta_module(in_ch, att_type, ds):
if att_type == 'BAM':
return Attention(BAM(in_ch, ds))
elif att_type == 'PAM':
return Attention(PAM(in_ch, ds))
else:
raise ValueError("Unsupported attention type: {}".format(att_type))
class Backbone(nn.Layer, KaimingInitMixin):
def __init__(self, in_ch, arch, pretrained=True, strides=(2,1,2,2,2)):
super().__init__()
if arch == 'resnet18':
self.resnet = resnet.resnet18(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer())
elif arch == 'resnet34':
self.resnet = resnet.resnet34(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer())
elif arch == 'resnet50':
self.resnet = resnet.resnet50(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer())
else:
raise ValueError("Unsupported backbone architecture: {}".format(arch))
self._trim_resnet()
if in_ch != 3:
self.resnet.conv1 = nn.Conv2D(
in_ch,
64,
kernel_size=7,
stride=strides[0],
padding=3,
bias_attr=False
)
if not pretrained:
self.init_weight()
def forward(self, x):
x = self.resnet.conv1(x)
x = self.resnet.bn1(x)
x = self.resnet.relu(x)
x = self.resnet.maxpool(x)
x1 = self.resnet.layer1(x)
x2 = self.resnet.layer2(x1)
x3 = self.resnet.layer3(x2)
x4 = self.resnet.layer4(x3)
return x1, x2, x3, x4
def _trim_resnet(self):
self.resnet.avgpool = Identity()
self.resnet.fc = Identity()
class Decoder(nn.Layer, KaimingInitMixin):
def __init__(self, f_ch):
super().__init__()
self.dr1 = Conv1x1(64, 96, norm=True, act=True)
self.dr2 = Conv1x1(128, 96, norm=True, act=True)
self.dr3 = Conv1x1(256, 96, norm=True, act=True)
self.dr4 = Conv1x1(512, 96, norm=True, act=True)
self.conv_out = nn.Sequential(
Conv3x3(384, 256, norm=True, act=True),
nn.Dropout(0.5),
Conv1x1(256, f_ch, norm=True, act=True)
)
self.init_weight()
def forward(self, feats):
f1 = self.dr1(feats[0])
f2 = self.dr2(feats[1])
f3 = self.dr3(feats[2])
f4 = self.dr4(feats[3])
f2 = F.interpolate(f2, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
f3 = F.interpolate(f3, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
f4 = F.interpolate(f4, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
x = paddle.concat([f1, f2, f3, f4], axis=1)
y = self.conv_out(x)
return y
class BAM(nn.Layer):
def __init__(self, in_ch, ds):
super().__init__()
self.ds = ds
self.pool = nn.AvgPool2D(self.ds)
self.val_ch = in_ch
self.key_ch = in_ch // 8
self.conv_q = Conv1x1(in_ch, self.key_ch)
self.conv_k = Conv1x1(in_ch, self.key_ch)
self.conv_v = Conv1x1(in_ch, self.val_ch)
self.softmax = nn.Softmax(axis=-1)
def forward(self, x):
x = x.flatten(-2)
x_rs = self.pool(x)
b, c, h, w = paddle.shape(x_rs)
query = self.conv_q(x_rs).reshape((b,-1,h*w)).transpose((0,2,1))
key = self.conv_k(x_rs).reshape((b,-1,h*w))
energy = paddle.bmm(query, key)
energy = (self.key_ch**(-0.5)) * energy
attention = self.softmax(energy)
value = self.conv_v(x_rs).reshape((b,-1,w*h))
out = paddle.bmm(value, attention.transpose((0,2,1)))
out = out.reshape((b,c,h,w))
out = F.interpolate(out, scale_factor=self.ds)
out = out + x
return out.reshape(out.shape[:-1]+[out.shape[-1]//2, 2])
class PAMBlock(nn.Layer):
def __init__(self, in_ch, scale=1, ds=1):
super().__init__()
self.scale = scale
self.ds = ds
self.pool = nn.AvgPool2D(self.ds)
self.val_ch = in_ch
self.key_ch = in_ch // 8
self.conv_q = Conv1x1(in_ch, self.key_ch, norm=True)
self.conv_k = Conv1x1(in_ch, self.key_ch, norm=True)
self.conv_v = Conv1x1(in_ch, self.val_ch)
def forward(self, x):
x_rs = self.pool(x)
# Get query, key, and value.
query = self.conv_q(x_rs)
key = self.conv_k(x_rs)
value = self.conv_v(x_rs)
# Split the whole image into subregions.
b, c, h, w = paddle.shape(x_rs)
query = self._split_subregions(query)
key = self._split_subregions(key)
value = self._split_subregions(value)
# Perform subregion-wise attention.
out = self._attend(query, key, value)
# Stack subregions to reconstruct the whole image.
out = self._recons_whole(out, b, c, h, w)
out = F.interpolate(out, scale_factor=self.ds)
return out
def _attend(self, query, key, value):
energy = paddle.bmm(query.transpose((0,2,1)), key) # batch matrix multiplication
energy = (self.key_ch**(-0.5)) * energy
attention = F.softmax(energy, axis=-1)
out = paddle.bmm(value, attention.transpose((0,2,1)))
return out
def _split_subregions(self, x):
b, c, h, w = paddle.shape(x)
assert h % self.scale == 0 and w % self.scale == 0
x = x.reshape((b, c, self.scale, h//self.scale, self.scale, w//self.scale))
x = x.transpose((0,2,4,1,3,5)).reshape((b*self.scale*self.scale, c, -1))
return x
def _recons_whole(self, x, b, c, h, w):
x = x.reshape((b, self.scale, self.scale, c, h//self.scale, w//self.scale))
x = x.transpose((0,3,1,4,2,5)).reshape((b, c, h, w))
return x
class PAM(nn.Layer):
def __init__(self, in_ch, ds, scales=(1,2,4,8)):
super().__init__()
self.stages = nn.LayerList([
PAMBlock(in_ch, scale=s, ds=ds)
for s in scales
])
self.conv_out = Conv1x1(in_ch*len(scales), in_ch, bias=False)
def forward(self, x):
x = x.flatten(-2)
res = [stage(x) for stage in self.stages]
out = self.conv_out(paddle.concat(res, axis=1))
return out.reshape(out.shape[:-1]+[out.shape[-1]//2, 2])
class Attention(nn.Layer):
def __init__(self, att):
super().__init__()
self.att = att
def forward(self, x1, x2):
x = paddle.stack([x1, x2], axis=-1)
y = self.att(x)
return y[...,0], y[...,1]
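The Attention wrapper stacks the bi-temporal features along a new last axis, lets BAM or PAM attend over both time steps jointly, and then splits them back apart. A hedged smoke test (my own example; the ResNet backbone fetches ImageNet weights on first use):

# Hypothetical smoke test, not part of this PR.
import paddle
from paddlers.models.cd.models import STANet

# att_type='BAM' attends over the whole (optionally downsampled) feature map;
# 'PAM' additionally splits it into pyramid subregions.
model = STANet(in_channels=3, num_classes=2, att_type='BAM', ds_factor=1)
t1 = paddle.randn([1, 3, 256, 256])
t2 = paddle.randn([1, 3, 256, 256])
pred, = model(t1, t2)  # unpack the one-element tuple
print(pred.shape)  # expected: [1, 2, 256, 256]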

paddlers/models/cd/models/unet_ef.py
@@ -0,0 +1,201 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity
from .param_init import normal_init, constant_init
class UNetEarlyFusion(nn.Layer):
"""
The FC-EF implementation based on PaddlePaddle.
The original article refers to
Caye Daudt, R., et al. "Fully convolutional siamese networks for change detection"
(https://arxiv.org/abs/1810.08462).
Args:
in_channels (int): The number of bands of the input images.
num_classes (int): The number of target classes.
use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is trained
on a relatively small dataset, the dropout layers help prevent overfitting. Default: False.
"""
def __init__(
self,
in_channels,
num_classes,
use_dropout=False
):
super().__init__()
C1, C2, C3, C4, C5 = 16, 32, 64, 128, 256
self.use_dropout = use_dropout
self.conv11 = Conv3x3(in_channels, C1, norm=True, act=True)
self.do11 = self._make_dropout()
self.conv12 = Conv3x3(C1, C1, norm=True, act=True)
self.do12 = self._make_dropout()
self.pool1 = MaxPool2x2()
self.conv21 = Conv3x3(C1, C2, norm=True, act=True)
self.do21 = self._make_dropout()
self.conv22 = Conv3x3(C2, C2, norm=True, act=True)
self.do22 = self._make_dropout()
self.pool2 = MaxPool2x2()
self.conv31 = Conv3x3(C2, C3, norm=True, act=True)
self.do31 = self._make_dropout()
self.conv32 = Conv3x3(C3, C3, norm=True, act=True)
self.do32 = self._make_dropout()
self.conv33 = Conv3x3(C3, C3, norm=True, act=True)
self.do33 = self._make_dropout()
self.pool3 = MaxPool2x2()
self.conv41 = Conv3x3(C3, C4, norm=True, act=True)
self.do41 = self._make_dropout()
self.conv42 = Conv3x3(C4, C4, norm=True, act=True)
self.do42 = self._make_dropout()
self.conv43 = Conv3x3(C4, C4, norm=True, act=True)
self.do43 = self._make_dropout()
self.pool4 = MaxPool2x2()
self.upconv4 = ConvTransposed3x3(C4, C4, output_padding=1)
self.conv43d = Conv3x3(C5, C4, norm=True, act=True)
self.do43d = self._make_dropout()
self.conv42d = Conv3x3(C4, C4, norm=True, act=True)
self.do42d = self._make_dropout()
self.conv41d = Conv3x3(C4, C3, norm=True, act=True)
self.do41d = self._make_dropout()
self.upconv3 = ConvTransposed3x3(C3, C3, output_padding=1)
self.conv33d = Conv3x3(C4, C3, norm=True, act=True)
self.do33d = self._make_dropout()
self.conv32d = Conv3x3(C3, C3, norm=True, act=True)
self.do32d = self._make_dropout()
self.conv31d = Conv3x3(C3, C2, norm=True, act=True)
self.do31d = self._make_dropout()
self.upconv2 = ConvTransposed3x3(C2, C2, output_padding=1)
self.conv22d = Conv3x3(C3, C2, norm=True, act=True)
self.do22d = self._make_dropout()
self.conv21d = Conv3x3(C2, C1, norm=True, act=True)
self.do21d = self._make_dropout()
self.upconv1 = ConvTransposed3x3(C1, C1, output_padding=1)
self.conv12d = Conv3x3(C2, C1, norm=True, act=True)
self.do12d = self._make_dropout()
self.conv11d = Conv3x3(C1, num_classes)
self.init_weight()
def forward(self, t1, t2):
x = paddle.concat([t1, t2], axis=1)
# Stage 1
x11 = self.do11(self.conv11(x))
x12 = self.do12(self.conv12(x11))
x1p = self.pool1(x12)
# Stage 2
x21 = self.do21(self.conv21(x1p))
x22 = self.do22(self.conv22(x21))
x2p = self.pool2(x22)
# Stage 3
x31 = self.do31(self.conv31(x2p))
x32 = self.do32(self.conv32(x31))
x33 = self.do33(self.conv33(x32))
x3p = self.pool3(x33)
# Stage 4
x41 = self.do41(self.conv41(x3p))
x42 = self.do42(self.conv42(x41))
x43 = self.do43(self.conv43(x42))
x4p = self.pool4(x43)
# Stage 4d
x4d = self.upconv4(x4p)
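        # F.pad takes (left, right, top, bottom); pad the right/bottom edges so the
        # upsampled map matches the encoder feature's spatial size before concat.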
pad4 = (
0,
paddle.shape(x43)[3]-paddle.shape(x4d)[3],
0,
paddle.shape(x43)[2]-paddle.shape(x4d)[2]
)
x4d = paddle.concat([F.pad(x4d, pad=pad4, mode='replicate'), x43], 1)
x43d = self.do43d(self.conv43d(x4d))
x42d = self.do42d(self.conv42d(x43d))
x41d = self.do41d(self.conv41d(x42d))
# Stage 3d
x3d = self.upconv3(x41d)
pad3 = (
0,
paddle.shape(x33)[3]-paddle.shape(x3d)[3],
0,
paddle.shape(x33)[2]-paddle.shape(x3d)[2]
)
x3d = paddle.concat([F.pad(x3d, pad=pad3, mode='replicate'), x33], 1)
x33d = self.do33d(self.conv33d(x3d))
x32d = self.do32d(self.conv32d(x33d))
x31d = self.do31d(self.conv31d(x32d))
# Stage 2d
x2d = self.upconv2(x31d)
pad2 = (
0,
paddle.shape(x22)[3]-paddle.shape(x2d)[3],
0,
paddle.shape(x22)[2]-paddle.shape(x2d)[2]
)
x2d = paddle.concat([F.pad(x2d, pad=pad2, mode='replicate'), x22], 1)
x22d = self.do22d(self.conv22d(x2d))
x21d = self.do21d(self.conv21d(x22d))
# Stage 1d
x1d = self.upconv1(x21d)
pad1 = (
0,
paddle.shape(x12)[3]-paddle.shape(x1d)[3],
0,
paddle.shape(x12)[2]-paddle.shape(x1d)[2]
)
x1d = paddle.concat([F.pad(x1d, pad=pad1, mode='replicate'), x12], 1)
x12d = self.do12d(self.conv12d(x1d))
x11d = self.conv11d(x12d)
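        # Return a one-element tuple: the surrounding framework consumes a sequence of outputs.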
return x11d,
def init_weight(self):
for sublayer in self.sublayers():
if isinstance(sublayer, nn.Conv2D):
normal_init(sublayer.weight, std=0.001)
elif isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)):
constant_init(sublayer.weight, value=1.0)
constant_init(sublayer.bias, value=0.0)
def _make_dropout(self):
if self.use_dropout:
return nn.Dropout2D(p=0.2)
else:
return Identity()
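
A minimal usage sketch (illustrative, not part of the patch; the import path and
input sizes are assumptions). With early fusion, the bi-temporal inputs are
concatenated, so `in_channels` counts the bands of both images together:

    import paddle
    from paddlers.models.cd.models import UNetEarlyFusion  # assumed import path

    model = UNetEarlyFusion(in_channels=6, num_classes=2)  # 6 = 2 x 3 RGB bands
    model.eval()
    t1 = paddle.rand([1, 3, 256, 256])
    t2 = paddle.rand([1, 3, 256, 256])
    out, = model(t1, t2)   # the forward pass returns a one-element tuple
    print(out.shape)       # [1, 2, 256, 256]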

@ -0,0 +1,224 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity
from .param_init import normal_init, constant_init
class UNetSiamConc(nn.Layer):
"""
The FC-Siam-conc implementation based on PaddlePaddle.
The original article refers to
Caye Daudt, R., et al. "Fully convolutional siamese networks for change detection"
(https://arxiv.org/abs/1810.08462).
Args:
in_channels (int): The number of bands of the input images.
num_classes (int): The number of target classes.
use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is trained
on a relatively small dataset, the dropout layers help prevent overfitting. Default: False.
"""
def __init__(
self,
in_channels,
num_classes,
use_dropout=False
):
super().__init__()
C1, C2, C3, C4, C5 = 16, 32, 64, 128, 256
self.use_dropout = use_dropout
self.conv11 = Conv3x3(in_channels, C1, norm=True, act=True)
self.do11 = self._make_dropout()
self.conv12 = Conv3x3(C1, C1, norm=True, act=True)
self.do12 = self._make_dropout()
self.pool1 = MaxPool2x2()
self.conv21 = Conv3x3(C1, C2, norm=True, act=True)
self.do21 = self._make_dropout()
self.conv22 = Conv3x3(C2, C2, norm=True, act=True)
self.do22 = self._make_dropout()
self.pool2 = MaxPool2x2()
self.conv31 = Conv3x3(C2, C3, norm=True, act=True)
self.do31 = self._make_dropout()
self.conv32 = Conv3x3(C3, C3, norm=True, act=True)
self.do32 = self._make_dropout()
self.conv33 = Conv3x3(C3, C3, norm=True, act=True)
self.do33 = self._make_dropout()
self.pool3 = MaxPool2x2()
self.conv41 = Conv3x3(C3, C4, norm=True, act=True)
self.do41 = self._make_dropout()
self.conv42 = Conv3x3(C4, C4, norm=True, act=True)
self.do42 = self._make_dropout()
self.conv43 = Conv3x3(C4, C4, norm=True, act=True)
self.do43 = self._make_dropout()
self.pool4 = MaxPool2x2()
self.upconv4 = ConvTransposed3x3(C4, C4, output_padding=1)
self.conv43d = Conv3x3(C5+C4, C4, norm=True, act=True)
self.do43d = self._make_dropout()
self.conv42d = Conv3x3(C4, C4, norm=True, act=True)
self.do42d = self._make_dropout()
self.conv41d = Conv3x3(C4, C3, norm=True, act=True)
self.do41d = self._make_dropout()
self.upconv3 = ConvTransposed3x3(C3, C3, output_padding=1)
self.conv33d = Conv3x3(C4+C3, C3, norm=True, act=True)
self.do33d = self._make_dropout()
self.conv32d = Conv3x3(C3, C3, norm=True, act=True)
self.do32d = self._make_dropout()
self.conv31d = Conv3x3(C3, C2, norm=True, act=True)
self.do31d = self._make_dropout()
self.upconv2 = ConvTransposed3x3(C2, C2, output_padding=1)
self.conv22d = Conv3x3(C3+C2, C2, norm=True, act=True)
self.do22d = self._make_dropout()
self.conv21d = Conv3x3(C2, C1, norm=True, act=True)
self.do21d = self._make_dropout()
self.upconv1 = ConvTransposed3x3(C1, C1, output_padding=1)
self.conv12d = Conv3x3(C2+C1, C1, norm=True, act=True)
self.do12d = self._make_dropout()
self.conv11d = Conv3x3(C1, num_classes)
self.init_weight()
def forward(self, t1, t2):
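        # Both dates are encoded by the same layers (weight sharing); the
        # per-date features are kept for the concatenation skips in the decoder.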
# Encode t1
# Stage 1
x11 = self.do11(self.conv11(t1))
x12_1 = self.do12(self.conv12(x11))
x1p = self.pool1(x12_1)
# Stage 2
x21 = self.do21(self.conv21(x1p))
x22_1 = self.do22(self.conv22(x21))
x2p = self.pool2(x22_1)
# Stage 3
x31 = self.do31(self.conv31(x2p))
x32 = self.do32(self.conv32(x31))
x33_1 = self.do33(self.conv33(x32))
x3p = self.pool3(x33_1)
# Stage 4
x41 = self.do41(self.conv41(x3p))
x42 = self.do42(self.conv42(x41))
x43_1 = self.do43(self.conv43(x42))
x4p = self.pool4(x43_1)
# Encode t2
# Stage 1
x11 = self.do11(self.conv11(t2))
x12_2 = self.do12(self.conv12(x11))
x1p = self.pool1(x12_2)
# Stage 2
x21 = self.do21(self.conv21(x1p))
x22_2 = self.do22(self.conv22(x21))
x2p = self.pool2(x22_2)
# Stage 3
x31 = self.do31(self.conv31(x2p))
x32 = self.do32(self.conv32(x31))
x33_2 = self.do33(self.conv33(x32))
x3p = self.pool3(x33_2)
# Stage 4
x41 = self.do41(self.conv41(x3p))
x42 = self.do42(self.conv42(x41))
x43_2 = self.do43(self.conv43(x42))
x4p = self.pool4(x43_2)
# Decode
# Stage 4d
x4d = self.upconv4(x4p)
pad4 = (
0,
paddle.shape(x43_1)[3]-paddle.shape(x4d)[3],
0,
paddle.shape(x43_1)[2]-paddle.shape(x4d)[2]
)
x4d = paddle.concat([F.pad(x4d, pad=pad4, mode='replicate'), x43_1, x43_2], 1)
x43d = self.do43d(self.conv43d(x4d))
x42d = self.do42d(self.conv42d(x43d))
x41d = self.do41d(self.conv41d(x42d))
# Stage 3d
x3d = self.upconv3(x41d)
pad3 = (
0,
paddle.shape(x33_1)[3]-paddle.shape(x3d)[3],
0,
paddle.shape(x33_1)[2]-paddle.shape(x3d)[2]
)
x3d = paddle.concat([F.pad(x3d, pad=pad3, mode='replicate'), x33_1, x33_2], 1)
x33d = self.do33d(self.conv33d(x3d))
x32d = self.do32d(self.conv32d(x33d))
x31d = self.do31d(self.conv31d(x32d))
# Stage 2d
x2d = self.upconv2(x31d)
pad2 = (
0,
paddle.shape(x22_1)[3]-paddle.shape(x2d)[3],
0,
paddle.shape(x22_1)[2]-paddle.shape(x2d)[2]
)
x2d = paddle.concat([F.pad(x2d, pad=pad2, mode='replicate'), x22_1, x22_2], 1)
x22d = self.do22d(self.conv22d(x2d))
x21d = self.do21d(self.conv21d(x22d))
# Stage 1d
x1d = self.upconv1(x21d)
pad1 = (
0,
paddle.shape(x12_1)[3]-paddle.shape(x1d)[3],
0,
paddle.shape(x12_1)[2]-paddle.shape(x1d)[2]
)
x1d = paddle.concat([F.pad(x1d, pad=pad1, mode='replicate'), x12_1, x12_2], 1)
x12d = self.do12d(self.conv12d(x1d))
x11d = self.conv11d(x12d)
return x11d,
def init_weight(self):
for sublayer in self.sublayers():
if isinstance(sublayer, nn.Conv2D):
normal_init(sublayer.weight, std=0.001)
elif isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)):
constant_init(sublayer.weight, value=1.0)
constant_init(sublayer.bias, value=0.0)
def _make_dropout(self):
if self.use_dropout:
return nn.Dropout2D(p=0.2)
else:
return Identity()
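
A companion sketch (illustrative; import path assumed as above). The siamese
variants take 3-band images per date, because both dates pass through one
weight-shared encoder rather than being concatenated at the input:

    import paddle
    from paddlers.models.cd.models import UNetSiamConc  # assumed import path

    model = UNetSiamConc(in_channels=3, num_classes=2)
    model.eval()
    t1 = paddle.rand([1, 3, 256, 256])
    t2 = paddle.rand([1, 3, 256, 256])
    out, = model(t1, t2)
    print(out.shape)  # [1, 2, 256, 256]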

@ -0,0 +1,224 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity
from .param_init import normal_init, constant_init
class UNetSiamDiff(nn.Layer):
"""
The FC-Siam-diff implementation based on PaddlePaddle.
The original article refers to
Caye Daudt, R., et al. "Fully convolutional siamese networks for change detection"
(https://arxiv.org/abs/1810.08462).
Args:
in_channels (int): The number of bands of the input images.
num_classes (int): The number of target classes.
use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is trained
on a relatively small dataset, the dropout layers help prevent overfitting. Default: False.
"""
def __init__(
self,
in_channels,
num_classes,
use_dropout=False
):
super().__init__()
C1, C2, C3, C4, C5 = 16, 32, 64, 128, 256
self.use_dropout = use_dropout
self.conv11 = Conv3x3(in_channels, C1, norm=True, act=True)
self.do11 = self._make_dropout()
self.conv12 = Conv3x3(C1, C1, norm=True, act=True)
self.do12 = self._make_dropout()
self.pool1 = MaxPool2x2()
self.conv21 = Conv3x3(C1, C2, norm=True, act=True)
self.do21 = self._make_dropout()
self.conv22 = Conv3x3(C2, C2, norm=True, act=True)
self.do22 = self._make_dropout()
self.pool2 = MaxPool2x2()
self.conv31 = Conv3x3(C2, C3, norm=True, act=True)
self.do31 = self._make_dropout()
self.conv32 = Conv3x3(C3, C3, norm=True, act=True)
self.do32 = self._make_dropout()
self.conv33 = Conv3x3(C3, C3, norm=True, act=True)
self.do33 = self._make_dropout()
self.pool3 = MaxPool2x2()
self.conv41 = Conv3x3(C3, C4, norm=True, act=True)
self.do41 = self._make_dropout()
self.conv42 = Conv3x3(C4, C4, norm=True, act=True)
self.do42 = self._make_dropout()
self.conv43 = Conv3x3(C4, C4, norm=True, act=True)
self.do43 = self._make_dropout()
self.pool4 = MaxPool2x2()
self.upconv4 = ConvTransposed3x3(C4, C4, output_padding=1)
self.conv43d = Conv3x3(C5, C4, norm=True, act=True)
self.do43d = self._make_dropout()
self.conv42d = Conv3x3(C4, C4, norm=True, act=True)
self.do42d = self._make_dropout()
self.conv41d = Conv3x3(C4, C3, norm=True, act=True)
self.do41d = self._make_dropout()
self.upconv3 = ConvTransposed3x3(C3, C3, output_padding=1)
self.conv33d = Conv3x3(C4, C3, norm=True, act=True)
self.do33d = self._make_dropout()
self.conv32d = Conv3x3(C3, C3, norm=True, act=True)
self.do32d = self._make_dropout()
self.conv31d = Conv3x3(C3, C2, norm=True, act=True)
self.do31d = self._make_dropout()
self.upconv2 = ConvTransposed3x3(C2, C2, output_padding=1)
self.conv22d = Conv3x3(C3, C2, norm=True, act=True)
self.do22d = self._make_dropout()
self.conv21d = Conv3x3(C2, C1, norm=True, act=True)
self.do21d = self._make_dropout()
self.upconv1 = ConvTransposed3x3(C1, C1, output_padding=1)
self.conv12d = Conv3x3(C2, C1, norm=True, act=True)
self.do12d = self._make_dropout()
self.conv11d = Conv3x3(C1, num_classes)
self.init_weight()
def forward(self, t1, t2):
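        # Weight-shared encoder as in FC-Siam-conc; the decoder starts from the
        # second date's deepest features, and the skips fuse the two dates by
        # absolute difference.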
# Encode t1
# Stage 1
x11 = self.do11(self.conv11(t1))
x12_1 = self.do12(self.conv12(x11))
x1p = self.pool1(x12_1)
# Stage 2
x21 = self.do21(self.conv21(x1p))
x22_1 = self.do22(self.conv22(x21))
x2p = self.pool2(x22_1)
# Stage 3
x31 = self.do31(self.conv31(x2p))
x32 = self.do32(self.conv32(x31))
x33_1 = self.do33(self.conv33(x32))
x3p = self.pool3(x33_1)
# Stage 4
x41 = self.do41(self.conv41(x3p))
x42 = self.do42(self.conv42(x41))
x43_1 = self.do43(self.conv43(x42))
x4p = self.pool4(x43_1)
# Encode t2
# Stage 1
x11 = self.do11(self.conv11(t2))
x12_2 = self.do12(self.conv12(x11))
x1p = self.pool1(x12_2)
# Stage 2
x21 = self.do21(self.conv21(x1p))
x22_2 = self.do22(self.conv22(x21))
x2p = self.pool2(x22_2)
# Stage 3
x31 = self.do31(self.conv31(x2p))
x32 = self.do32(self.conv32(x31))
x33_2 = self.do33(self.conv33(x32))
x3p = self.pool3(x33_2)
# Stage 4
x41 = self.do41(self.conv41(x3p))
x42 = self.do42(self.conv42(x41))
x43_2 = self.do43(self.conv43(x42))
x4p = self.pool4(x43_2)
# Decode
# Stage 4d
x4d = self.upconv4(x4p)
pad4 = (
0,
paddle.shape(x43_1)[3]-paddle.shape(x4d)[3],
0,
paddle.shape(x43_1)[2]-paddle.shape(x4d)[2]
)
x4d = paddle.concat([F.pad(x4d, pad=pad4, mode='replicate'), paddle.abs(x43_1-x43_2)], 1)
x43d = self.do43d(self.conv43d(x4d))
x42d = self.do42d(self.conv42d(x43d))
x41d = self.do41d(self.conv41d(x42d))
# Stage 3d
x3d = self.upconv3(x41d)
pad3 = (
0,
paddle.shape(x33_1)[3]-paddle.shape(x3d)[3],
0,
paddle.shape(x33_1)[2]-paddle.shape(x3d)[2]
)
x3d = paddle.concat([F.pad(x3d, pad=pad3, mode='replicate'), paddle.abs(x33_1-x33_2)], 1)
x33d = self.do33d(self.conv33d(x3d))
x32d = self.do32d(self.conv32d(x33d))
x31d = self.do31d(self.conv31d(x32d))
# Stage 2d
x2d = self.upconv2(x31d)
pad2 = (
0,
paddle.shape(x22_1)[3]-paddle.shape(x2d)[3],
0,
paddle.shape(x22_1)[2]-paddle.shape(x2d)[2]
)
x2d = paddle.concat([F.pad(x2d, pad=pad2, mode='replicate'), paddle.abs(x22_1-x22_2)], 1)
x22d = self.do22d(self.conv22d(x2d))
x21d = self.do21d(self.conv21d(x22d))
# Stage 1d
x1d = self.upconv1(x21d)
pad1 = (
0,
paddle.shape(x12_1)[3]-paddle.shape(x1d)[3],
0,
paddle.shape(x12_1)[2]-paddle.shape(x1d)[2]
)
x1d = paddle.concat([F.pad(x1d, pad=pad1, mode='replicate'), paddle.abs(x12_1-x12_2)], 1)
x12d = self.do12d(self.conv12d(x1d))
x11d = self.conv11d(x12d)
return x11d,
def init_weight(self):
for sublayer in self.sublayers():
if isinstance(sublayer, nn.Conv2D):
normal_init(sublayer.weight, std=0.001)
elif isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)):
constant_init(sublayer.weight, value=1.0)
constant_init(sublayer.bias, value=0.0)
def _make_dropout(self):
if self.use_dropout:
return nn.Dropout2D(p=0.2)
else:
return Identity()
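
A companion sketch (illustrative; import path assumed as above). FC-Siam-diff
feeds the decoder the deepest features of the second date and injects
|f(t1) - f(t2)| through the skips, so its decoder widths match the early-fusion
variant rather than FC-Siam-conc:

    import paddle
    from paddlers.models.cd.models import UNetSiamDiff  # assumed import path

    model = UNetSiamDiff(in_channels=3, num_classes=2)
    model.eval()
    t1 = paddle.rand([1, 3, 256, 256])
    t2 = paddle.rand([1, 3, 256, 256])
    out, = model(t1, t2)
    print(out.shape)  # [1, 2, 256, 256]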

@ -31,7 +31,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict
from paddlers.transforms import ImgDecoder, Resize
import paddlers.models.cd as cd
__all__ = ["CDNet"]
__all__ = ["CDNet", "UNetEarlyFusion", "UNetSiamConc", "UNetSiamDiff", "STANet", "BIT", "SNUNet", "DSIFN", "DSAMNet"]
class BaseChangeDetector(BaseModel):
@ -663,3 +663,190 @@ class CDNet(BaseChangeDetector):
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
**params)
class UNetEarlyFusion(BaseChangeDetector):
def __init__(self,
num_classes=2,
use_mixed_loss=False,
in_channels=6,
use_dropout=False,
**params):
params.update({
'in_channels': in_channels,
'use_dropout': use_dropout
})
super(UNetEarlyFusion, self).__init__(
model_name='UNetEarlyFusion',
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
**params)
class UNetSiamConc(BaseChangeDetector):
def __init__(self,
num_classes=2,
use_mixed_loss=False,
in_channels=3,
use_dropout=False,
**params):
params.update({
'in_channels': in_channels,
'use_dropout': use_dropout
})
super(UNetSiamConc, self).__init__(
model_name='UNetSiamConc',
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
**params)
class UNetSiamDiff(BaseChangeDetector):
def __init__(self,
num_classes=2,
use_mixed_loss=False,
in_channels=3,
use_dropout=False,
**params):
params.update({
'in_channels': in_channels,
'use_dropout': use_dropout
})
super(UNetSiamDiff, self).__init__(
model_name='UNetSiamDiff',
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
**params)
class STANet(BaseChangeDetector):
def __init__(self,
num_classes=2,
use_mixed_loss=False,
in_channels=3,
att_type='BAM',
ds_factor=1,
**params):
params.update({
'in_channels': in_channels,
'att_type': att_type,
'ds_factor': ds_factor
})
super(STANet, self).__init__(
model_name='STANet',
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
**params)
class BIT(BaseChangeDetector):
def __init__(self,
num_classes=2,
use_mixed_loss=False,
in_channels=3,
backbone='resnet18',
n_stages=4,
use_tokenizer=True,
token_len=4,
pool_mode='max',
pool_size=2,
enc_with_pos=True,
enc_depth=1,
enc_head_dim=64,
dec_depth=8,
dec_head_dim=8,
**params):
params.update({
'in_channels': in_channels,
'backbone': backbone,
'n_stages': n_stages,
'use_tokenizer': use_tokenizer,
'token_len': token_len,
'pool_mode': pool_mode,
'pool_size': pool_size,
'enc_with_pos': enc_with_pos,
'enc_depth': enc_depth,
'enc_head_dim': enc_head_dim,
'dec_depth': dec_depth,
'dec_head_dim': dec_head_dim
})
super(BIT, self).__init__(
model_name='BIT',
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
**params)
class SNUNet(BaseChangeDetector):
def __init__(self,
num_classes=2,
use_mixed_loss=False,
in_channels=3,
width=32,
**params):
params.update({
'in_channels': in_channels,
'width': width
})
super(SNUNet, self).__init__(
model_name='SNUNet',
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
**params)
class DSIFN(BaseChangeDetector):
def __init__(self,
num_classes=2,
use_mixed_loss=None,
use_dropout=False,
**params):
params.update({
'use_dropout': use_dropout
})
super(DSIFN, self).__init__(
model_name='DSIFN',
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
**params)
# HACK: currently the only legal value of `use_mixed_loss` is None, in which case the loss specifications are
# constructed automatically.
assert use_mixed_loss is None
if use_mixed_loss is None:
self.losses = {
                # Build five independent loss objects (one per supervised output)
                # rather than replicating a single instance via list multiplication.
                'types': [paddleseg.models.CrossEntropyLoss() for _ in range(5)],
'coef': [1.0]*5
}
class DSAMNet(BaseChangeDetector):
def __init__(self,
num_classes=2,
use_mixed_loss=None,
in_channels=3,
ca_ratio=8,
sa_kernel=7,
**params):
params.update({
'in_channels': in_channels,
'ca_ratio': ca_ratio,
'sa_kernel': sa_kernel
})
super(DSAMNet, self).__init__(
model_name='DSAMNet',
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
**params)
# HACK: currently the only legal value of `use_mixed_loss` is None, in which case the loss specifications are
# constructed automatically.
assert use_mixed_loss is None
if use_mixed_loss is None:
self.losses = {
'types': [
paddleseg.models.CrossEntropyLoss(),
paddleseg.models.DiceLoss(),
paddleseg.models.DiceLoss()
],
'coef': [1.0, 0.05, 0.05]
}
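
A construction sketch for the new trainers (hypothetical settings; every extra
keyword argument is forwarded to the underlying network through `params`, and
the import path is an assumption):

    from paddlers.tasks.changedetector import BIT, DSAMNet  # assumed import path

    bit = BIT(num_classes=2, backbone='resnet18', token_len=4)
    # DSIFN and DSAMNet currently require use_mixed_loss=None; their loss
    # specifications are built automatically (see the HACK notes above).
    dsamnet = DSAMNet(num_classes=2, ca_ratio=8, sa_kernel=7)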