Merge pull request #12 from Bobholamovic/dev_gdf

[Feature] Add eight change detection models
own
Liu Yi 3 years ago committed by GitHub
commit d5a7ddaf6f
  1. 10
      paddlers/models/cd/models/__init__.py
  2. 13
      paddlers/models/cd/models/backbones/__init__.py
  3. 358
      paddlers/models/cd/models/backbones/resnet.py
  4. 395
      paddlers/models/cd/models/bit.py
  5. 96
      paddlers/models/cd/models/dsamnet.py
  6. 209
      paddlers/models/cd/models/dsifn.py
  7. 16
      paddlers/models/cd/models/layers/__init__.py
  8. 96
      paddlers/models/cd/models/layers/attention.py
  9. 142
      paddlers/models/cd/models/layers/blocks.py
  10. 86
      paddlers/models/cd/models/param_init.py
  11. 155
      paddlers/models/cd/models/snunet.py
  12. 298
      paddlers/models/cd/models/stanet.py
  13. 201
      paddlers/models/cd/models/unet_ef.py
  14. 224
      paddlers/models/cd/models/unet_siamconc.py
  15. 224
      paddlers/models/cd/models/unet_siamdiff.py
  16. 189
      paddlers/tasks/changedetector.py

@ -12,4 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .cdnet import CDNet
from .bit import BIT
from .cdnet import CDNet
from .dsifn import DSIFN
from .stanet import STANet
from .snunet import SNUNet
from .dsamnet import DSAMNet
from .unet_ef import UNetEarlyFusion
from .unet_siamconc import UNetSiamConc
from .unet_siamdiff import UNetSiamDiff

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,358 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/PaddlePaddle/Paddle/blob/release/2.2/python/paddle/vision/models/resnet.py
## Original head information
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
from paddle.utils.download import get_weights_path_from_url
__all__ = []
# Pretrained checkpoints keyed by architecture name. Every value is a
# ``(url, md5)`` pair; the URL only differs by the architecture name, so the
# table is generated from the ``(name, md5)`` pairs below.
model_urls = {
    arch: (
        'https://paddle-hapi.bj.bcebos.com/models/{}.pdparams'.format(arch),
        md5,
    )
    for arch, md5 in (
        ('resnet18', 'cf548f46534aa3560945be4b95cd11c4'),
        ('resnet34', '8d2275cf8706028345f78ac0e1d31969'),
        ('resnet50', 'ca6f485ee1ab0492d38f323885b0ad80'),
        ('resnet101', '02f35f034ca3858e1e54d4036443c92d'),
        ('resnet152', '7ad16a2f1e7333859ff986138630fd7a'),
    )
}
class BasicBlock(nn.Layer):
    """Residual block built from two 3x3 convolutions.

    The output is ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``, where
    ``shortcut`` is ``x`` itself, or ``downsample(x)`` when a ``downsample``
    layer is given.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of output channels.
        stride (int, optional): Stride of the first convolution. Default: 1.
        downsample (nn.Layer|None, optional): Layer applied to the input to
            build the shortcut branch. Default: None.
        groups (int, optional): Accepted so that the signature matches
            ``BottleneckBlock``; not used here. Default: 1.
        base_width (int, optional): Accepted so that the signature matches
            ``BottleneckBlock``; not used here. Default: 64.
        dilation (int, optional): Must not be greater than 1. Default: 1.
        norm_layer (callable, optional): Called with a channel count to build
            each normalization layer. Default: ``nn.BatchNorm2D``.

    Raises:
        NotImplementedError: If ``dilation`` is greater than 1.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super().__init__()

        make_norm_layer = nn.BatchNorm2D if norm_layer is None else norm_layer
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")

        self.conv1 = nn.Conv2D(
            inplanes, planes, 3, padding=1, stride=stride, bias_attr=False)
        self.bn1 = make_norm_layer(planes)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False)
        self.bn2 = make_norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the residual branch and add the shortcut branch to it."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        residual += shortcut
        return self.relu(residual)
class BottleneckBlock(nn.Layer):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions.

    The block outputs ``planes * expansion`` channels. The shortcut branch is
    the input itself, or ``downsample(x)`` when a ``downsample`` layer is
    given.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Base channel count; the output has
            ``planes * expansion`` channels.
        stride (int, optional): Stride of the 3x3 convolution. Default: 1.
        downsample (nn.Layer|None, optional): Layer applied to the input to
            build the shortcut branch. Default: None.
        groups (int, optional): Groups of the 3x3 convolution. Default: 1.
        base_width (int, optional): Scales the inner width as
            ``int(planes * (base_width / 64.)) * groups``. Default: 64.
        dilation (int, optional): Dilation (and padding) of the 3x3
            convolution. Default: 1.
        norm_layer (callable, optional): Called with a channel count to build
            each normalization layer. Default: ``nn.BatchNorm2D``.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super().__init__()

        make_norm_layer = nn.BatchNorm2D if norm_layer is None else norm_layer
        inner_ch = int(planes * (base_width / 64.)) * groups
        out_ch = planes * self.expansion

        self.conv1 = nn.Conv2D(inplanes, inner_ch, 1, bias_attr=False)
        self.bn1 = make_norm_layer(inner_ch)

        self.conv2 = nn.Conv2D(
            inner_ch,
            inner_ch,
            3,
            padding=dilation,
            stride=stride,
            groups=groups,
            dilation=dilation,
            bias_attr=False)
        self.bn2 = make_norm_layer(inner_ch)

        self.conv3 = nn.Conv2D(inner_ch, out_ch, 1, bias_attr=False)
        self.bn3 = make_norm_layer(out_ch)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the bottleneck branch and add the shortcut branch to it."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.relu(self.bn2(self.conv2(residual)))
        residual = self.bn3(self.conv3(residual))

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        residual += shortcut
        return self.relu(residual)
class ResNet(nn.Layer):
    """ResNet model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Compared with a plain ResNet constructor, this variant lets the caller
    pick the stride of the stem convolution and of each residual stage, as
    well as the normalization layer.

    Args:
        block (BasicBlock|BottleneckBlock): Block class used in every stage.
        depth (int): Depth of the network. Must be one of 18, 34, 50, 101 or
            152; any other value raises ``KeyError``.
        num_classes (int, optional): Output dim of the last fc layer. If
            ``num_classes <= 0``, the fc layer is not created. Default: 1000.
        with_pool (bool, optional): Whether to apply global average pooling
            after the last stage. Default: True.
        strides (tuple[int], optional): Five strides: ``strides[0]`` for the
            stem convolution and ``strides[1:]`` for ``layer1`` to ``layer4``.
            Default: (1, 1, 2, 2, 2).
        norm_layer (callable, optional): Called with a channel count to build
            each normalization layer. Default: ``nn.BatchNorm2D``.

    Examples:
        .. code-block:: python

            resnet50 = ResNet(BottleneckBlock, 50)
            resnet18 = ResNet(BasicBlock, 18)
    """

    def __init__(self,
                 block,
                 depth,
                 num_classes=1000,
                 with_pool=True,
                 strides=(1, 1, 2, 2, 2),
                 norm_layer=None):
        super().__init__()

        # Number of blocks in each of the four stages, per supported depth.
        blocks_per_stage = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3]
        }[depth]

        self.num_classes = num_classes
        self.with_pool = with_pool
        if norm_layer is None:
            self._norm_layer = nn.BatchNorm2D
        else:
            self._norm_layer = norm_layer

        # Running state consumed and updated by ``_make_layer``.
        self.inplanes = 64
        self.dilation = 1

        # Stem.
        self.conv1 = nn.Conv2D(
            3,
            self.inplanes,
            kernel_size=7,
            stride=strides[0],
            padding=3,
            bias_attr=False)
        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)

        # Residual stages.
        self.layer1 = self._make_layer(
            block, 64, blocks_per_stage[0], stride=strides[1])
        self.layer2 = self._make_layer(
            block, 128, blocks_per_stage[1], stride=strides[2])
        self.layer3 = self._make_layer(
            block, 256, blocks_per_stage[2], stride=strides[3])
        self.layer4 = self._make_layer(
            block, 512, blocks_per_stage[3], stride=strides[4])

        # Optional classification head.
        if with_pool:
            self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
        if num_classes > 0:
            self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage made of ``blocks`` blocks.

        Only the first block receives ``stride`` and, if the channel count or
        resolution changes, a 1x1-conv + norm shortcut. ``self.inplanes`` is
        updated to the stage's output channel count.
        """
        make_norm_layer = self._norm_layer
        first_block_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation.
            self.dilation *= stride
            stride = 1

        out_planes = planes * block.expansion
        shortcut = None
        if stride != 1 or self.inplanes != out_planes:
            shortcut = nn.Sequential(
                nn.Conv2D(
                    self.inplanes,
                    out_planes,
                    1,
                    stride=stride,
                    bias_attr=False),
                make_norm_layer(out_planes), )

        stage = [
            block(self.inplanes, planes, stride, shortcut, 1, 64,
                  first_block_dilation, make_norm_layer)
        ]
        self.inplanes = out_planes
        stage.extend(
            block(self.inplanes, planes, norm_layer=make_norm_layer)
            for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem, the four stages and the optional pooling/fc head."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        if self.with_pool:
            x = self.avgpool(x)

        if self.num_classes > 0:
            x = self.fc(paddle.flatten(x, 1))

        return x
def _resnet(arch, Block, depth, pretrained, **kwargs):
    """Build a ``ResNet`` and optionally load pretrained weights into it.

    Args:
        arch (str): Key into ``model_urls`` used to locate the weights.
        Block (BasicBlock|BottleneckBlock): Block class of the network.
        depth (int): Depth passed to ``ResNet``.
        pretrained (bool): Whether to download and load pretrained weights.
        **kwargs: Forwarded to the ``ResNet`` constructor.

    Returns:
        ResNet: The constructed model.
    """
    net = ResNet(Block, depth, **kwargs)
    if not pretrained:
        return net

    assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
        arch)
    url, md5sum = model_urls[arch][0], model_urls[arch][1]
    state = paddle.load(get_weights_path_from_url(url, md5sum))
    net.set_dict(state)
    return net
def resnet18(pretrained=False, **kwargs):
    """Build an 18-layer ResNet made of ``BasicBlock`` blocks.

    Args:
        pretrained (bool, optional): If True, load the weights registered
            under ``'resnet18'`` in ``model_urls``. Default: False.
        **kwargs: Forwarded to the ``ResNet`` constructor.

    Examples:
        .. code-block:: python

            # build model
            model = resnet18()

            # build model and load pretrained weight
            # model = resnet18(pretrained=True)
    """
    return _resnet(
        arch='resnet18',
        Block=BasicBlock,
        depth=18,
        pretrained=pretrained,
        **kwargs)
def resnet34(pretrained=False, **kwargs):
    """Build a 34-layer ResNet made of ``BasicBlock`` blocks.

    Args:
        pretrained (bool, optional): If True, load the weights registered
            under ``'resnet34'`` in ``model_urls``. Default: False.
        **kwargs: Forwarded to the ``ResNet`` constructor.

    Examples:
        .. code-block:: python

            # build model
            model = resnet34()

            # build model and load pretrained weight
            # model = resnet34(pretrained=True)
    """
    return _resnet(
        arch='resnet34',
        Block=BasicBlock,
        depth=34,
        pretrained=pretrained,
        **kwargs)
def resnet50(pretrained=False, **kwargs):
    """Build a 50-layer ResNet made of ``BottleneckBlock`` blocks.

    Args:
        pretrained (bool, optional): If True, load the weights registered
            under ``'resnet50'`` in ``model_urls``. Default: False.
        **kwargs: Forwarded to the ``ResNet`` constructor.

    Examples:
        .. code-block:: python

            # build model
            model = resnet50()

            # build model and load pretrained weight
            # model = resnet50(pretrained=True)
    """
    return _resnet(
        arch='resnet50',
        Block=BottleneckBlock,
        depth=50,
        pretrained=pretrained,
        **kwargs)
def resnet101(pretrained=False, **kwargs):
    """Build a 101-layer ResNet made of ``BottleneckBlock`` blocks.

    Args:
        pretrained (bool, optional): If True, load the weights registered
            under ``'resnet101'`` in ``model_urls``. Default: False.
        **kwargs: Forwarded to the ``ResNet`` constructor.

    Examples:
        .. code-block:: python

            # build model
            model = resnet101()

            # build model and load pretrained weight
            # model = resnet101(pretrained=True)
    """
    return _resnet(
        arch='resnet101',
        Block=BottleneckBlock,
        depth=101,
        pretrained=pretrained,
        **kwargs)
def resnet152(pretrained=False, **kwargs):
    """Build a 152-layer ResNet made of ``BottleneckBlock`` blocks.

    Args:
        pretrained (bool, optional): If True, load the weights registered
            under ``'resnet152'`` in ``model_urls``. Default: False.
        **kwargs: Forwarded to the ``ResNet`` constructor.

    Examples:
        .. code-block:: python

            # build model
            model = resnet152()

            # build model and load pretrained weight
            # model = resnet152(pretrained=True)
    """
    return _resnet(
        arch='resnet152',
        Block=BottleneckBlock,
        depth=152,
        pretrained=pretrained,
        **kwargs)

@ -0,0 +1,395 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal
from .backbones import resnet
from .layers import Conv3x3, Conv1x1, get_norm_layer, Identity
from .param_init import KaimingInitMixin
class BIT(nn.Layer):
    """
    The BIT implementation based on PaddlePaddle.

    The original article refers to
        H. Chen, et al., "Remote Sensing Image Change Detection With Transformers"
        (https://arxiv.org/abs/2103.00208).

    This implementation adopts pretrained encoders, as opposed to the original work where weights are randomly
    initialized.

    Args:
        in_channels (int): The number of bands of the input images.
        num_classes (int): The number of target classes.
        backbone (str, optional): The ResNet architecture that is used as the backbone. Currently, only 'resnet18' and
            'resnet34' are supported. Default: 'resnet18'.
        n_stages (int, optional): The number of ResNet stages used in the backbone, which should be a value in {3,4,5}.
            Default: 4.
        use_tokenizer (bool, optional): Use a tokenizer or not. Default: True.
        token_len (int, optional): The length of input tokens. Only used when `use_tokenizer` is True; otherwise the
            token length is `pool_size * pool_size`. Default: 4.
        pool_mode (str, optional): The pooling strategy to obtain input tokens when `use_tokenizer` is set to False.
            'max' for adaptive max pooling and 'avg' for adaptive average pooling. Any other value skips the pooling
            step. Default: 'max'.
        pool_size (int, optional): The height and width of the pooled feature maps when `use_tokenizer` is set to
            False. Default: 2.
        enc_with_pos (bool, optional): Whether to add learned positional embedding to the input feature sequence of
            the encoder. Default: True.
        enc_depth (int, optional): The number of attention blocks used in the encoder. Default: 1.
        enc_head_dim (int, optional): The embedding dimension of each encoder head. Default: 64.
        dec_depth (int, optional): The number of attention blocks used in the decoder. Default: 8.
        dec_head_dim (int, optional): The embedding dimension of each decoder head. Default: 8.
        **backbone_kwargs: Extra keyword arguments forwarded to the backbone constructor.

    Raises:
        ValueError: When an unsupported backbone type is specified, or the number of backbone stages is not 3, 4, or 5.
    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 backbone='resnet18',
                 n_stages=4,
                 use_tokenizer=True,
                 token_len=4,
                 pool_mode='max',
                 pool_size=2,
                 enc_with_pos=True,
                 enc_depth=1,
                 enc_head_dim=64,
                 dec_depth=8,
                 dec_head_dim=8,
                 **backbone_kwargs):
        super().__init__()

        # TODO: reduce hard-coded parameters
        DIM = 32
        MLP_DIM = 2 * DIM
        EBD_DIM = DIM

        self.backbone = Backbone(
            in_channels,
            EBD_DIM,
            arch=backbone,
            n_stages=n_stages,
            **backbone_kwargs)

        self.use_tokenizer = use_tokenizer
        if use_tokenizer:
            # One attention map per semantic token.
            self.conv_att = Conv1x1(EBD_DIM, token_len, bias=False)
            self.token_len = token_len
        else:
            # If a tokenizer is not to be used, then downsample the feature maps.
            self.pool_size = pool_size
            self.pool_mode = pool_mode
            self.token_len = pool_size * pool_size

        self.enc_with_pos = enc_with_pos
        if enc_with_pos:
            # Covers the concatenated tokens of both temporal phases.
            self.enc_pos_embedding = self.create_parameter(
                shape=(1, self.token_len * 2, EBD_DIM),
                default_initializer=Normal())

        self.enc_depth = enc_depth
        self.dec_depth = dec_depth
        self.enc_head_dim = enc_head_dim
        self.dec_head_dim = dec_head_dim

        self.encoder = TransformerEncoder(
            dim=DIM,
            depth=enc_depth,
            n_heads=8,
            head_dim=enc_head_dim,
            mlp_dim=MLP_DIM,
            dropout_rate=0.)
        self.decoder = TransformerDecoder(
            dim=DIM,
            depth=dec_depth,
            n_heads=8,
            head_dim=dec_head_dim,
            mlp_dim=MLP_DIM,
            dropout_rate=0.,
            apply_softmax=True)

        self.upsample = nn.Upsample(scale_factor=4, mode='bilinear')
        self.conv_out = nn.Sequential(
            Conv3x3(
                EBD_DIM, EBD_DIM, norm=True, act=True),
            Conv3x3(EBD_DIM, num_classes))

    def _get_semantic_tokens(self, x):
        """Pool a (b, c, h, w) feature map into (b, token_len, c) tokens.

        Each token is the sum of all spatial positions weighted by a
        spatial softmax of one attention map produced by ``conv_att``.
        """
        batch, channels = paddle.shape(x)[:2]
        weights = self.conv_att(x).reshape((batch, self.token_len, 1, -1))
        weights = F.softmax(weights, axis=-1)
        flat_feats = x.reshape((batch, 1, channels, -1))
        return (flat_feats * weights).sum(-1)

    def _get_reshaped_tokens(self, x):
        """Pool the feature map and flatten it into a (b, h*w, c) sequence.

        The map is left unpooled when ``pool_mode`` is neither 'max' nor
        'avg'.
        """
        target_size = (self.pool_size, self.pool_size)
        if self.pool_mode == 'max':
            x = F.adaptive_max_pool2d(x, target_size)
        elif self.pool_mode == 'avg':
            x = F.adaptive_avg_pool2d(x, target_size)
        return x.transpose((0, 2, 3, 1)).flatten(1, 2)

    def encode(self, x):
        """Run the transformer encoder, adding the positional embedding first if enabled."""
        if self.enc_with_pos:
            x += self.enc_pos_embedding
        return self.encoder(x)

    def decode(self, x, m):
        """Refine the (b, c, h, w) feature map ``x`` with the tokens ``m``.

        The map is flattened to a pixel sequence, passed through the
        transformer decoder, then restored to its original shape.
        """
        b, c, h, w = paddle.shape(x)
        pixels = x.transpose((0, 2, 3, 1)).flatten(1, 2)
        pixels = self.decoder(pixels, m)
        return pixels.transpose((0, 2, 1)).reshape((b, c, h, w))

    def forward(self, t1, t2):
        """Predict change logits from a bi-temporal image pair.

        Returns:
            tuple: A 1-tuple holding the output of the classifier.
        """
        # Extract features via shared backbone.
        feat1 = self.backbone(t1)
        feat2 = self.backbone(t2)

        # Tokenization
        if self.use_tokenizer:
            tokenize = self._get_semantic_tokens
        else:
            tokenize = self._get_reshaped_tokens
        token1 = tokenize(feat1)
        token2 = tokenize(feat2)

        # Transformer encoder forward
        tokens = self.encode(paddle.concat([token1, token2], axis=1))
        token1, token2 = paddle.chunk(tokens, 2, axis=1)

        # Transformer decoder forward
        refined1 = self.decode(feat1, token1)
        refined2 = self.decode(feat2, token2)

        # Feature differencing
        diff = self.upsample(paddle.abs(refined1 - refined2))

        # Classifier forward
        pred = self.conv_out(diff)
        return (pred, )

    def init_weight(self):
        """Keep the default initialization of every sublayer."""
        # Use the default initialization method.
        pass
class Residual(nn.Layer):
    """Wrap a single-input layer ``fn`` so that it computes ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Apply the wrapped layer and add the input back."""
        branch = self.fn(x, **kwargs)
        return branch + x
class Residual2(nn.Layer):
    """Wrap a two-input layer ``fn`` so that it computes ``fn(x1, x2) + x1``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x1, x2, **kwargs):
        """Apply the wrapped layer and add the first input back."""
        branch = self.fn(x1, x2, **kwargs)
        return branch + x1
class PreNorm(nn.Layer):
    """Apply layer normalization over ``dim`` features before calling ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalize ``x`` and pass it to the wrapped layer."""
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
class PreNorm2(nn.Layer):
    """Normalize both inputs with one shared ``LayerNorm`` before calling ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x1, x2, **kwargs):
        """Normalize ``x1`` and ``x2`` and pass them to the wrapped layer."""
        normed1 = self.norm(x1)
        normed2 = self.norm(x2)
        return self.fn(normed1, normed2, **kwargs)
class FeedForward(nn.Sequential):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim (int): Input and output feature size.
        hidden_dim (int): Feature size of the hidden layer.
        dropout_rate (float, optional): Rate of both dropout layers.
            Default: 0.
    """

    def __init__(self, dim, hidden_dim, dropout_rate=0.):
        expand = nn.Linear(dim, hidden_dim)
        activation = nn.GELU()
        hidden_drop = nn.Dropout(dropout_rate)
        project = nn.Linear(hidden_dim, dim)
        out_drop = nn.Dropout(dropout_rate)
        super().__init__(expand, activation, hidden_drop, project, out_drop)
class CrossAttention(nn.Layer):
    """Multi-head attention where queries come from ``x`` and keys/values from ``ref``.

    Args:
        dim (int): Feature size of the inputs and of the output.
        n_heads (int, optional): Number of attention heads. Default: 8.
        head_dim (int, optional): Feature size of each head. Default: 64.
        dropout_rate (float, optional): Dropout rate applied after the output
            projection. Default: 0.
        apply_softmax (bool, optional): Whether to normalize the attention
            scores with a softmax over the reference axis. Default: True.
    """

    def __init__(self,
                 dim,
                 n_heads=8,
                 head_dim=64,
                 dropout_rate=0.,
                 apply_softmax=True):
        super().__init__()

        inner_dim = head_dim * n_heads
        self.n_heads = n_heads
        # The scores are scaled by the input feature size, not by head_dim.
        self.scale = dim**-0.5
        self.apply_softmax = apply_softmax

        self.fc_q = nn.Linear(dim, inner_dim, bias_attr=False)
        self.fc_k = nn.Linear(dim, inner_dim, bias_attr=False)
        self.fc_v = nn.Linear(dim, inner_dim, bias_attr=False)

        self.fc_out = nn.Sequential(
            nn.Linear(inner_dim, dim), nn.Dropout(dropout_rate))

    def forward(self, x, ref):
        """Attend from every position of ``x`` to every position of ``ref``."""
        batch, n_query = paddle.shape(x)[:2]
        n_heads = self.n_heads

        def split_heads(seq, length):
            # (b, length, heads*d) -> (b, heads, length, d)
            return seq.reshape((batch, length, n_heads, -1)).transpose(
                (0, 2, 1, 3))

        queries = self.fc_q(x)
        keys = self.fc_k(ref)
        values = self.fc_v(ref)

        queries = split_heads(queries, n_query)
        keys = split_heads(keys, paddle.shape(ref)[1])
        values = split_heads(values, paddle.shape(ref)[1])

        scores = paddle.matmul(queries, keys, transpose_y=True) * self.scale
        if self.apply_softmax:
            scores = F.softmax(scores, axis=-1)

        # Merge the heads back: (b, heads, n, d) -> (b, n, heads*d)
        merged = paddle.matmul(scores, values).transpose(
            (0, 2, 1, 3)).flatten(2)
        return self.fc_out(merged)
class SelfAttention(CrossAttention):
    """``CrossAttention`` in which the input also serves as the reference."""

    def forward(self, x):
        """Attend from ``x`` to itself."""
        return super().forward(x, ref=x)
class TransformerEncoder(nn.Layer):
    """Stack of ``depth`` pre-norm residual (self-attention, feed-forward) pairs.

    Args:
        dim (int): Feature size of the token sequence.
        depth (int): Number of (attention, feed-forward) pairs.
        n_heads (int): Number of attention heads.
        head_dim (int): Feature size of each attention head.
        mlp_dim (int): Hidden size of the feed-forward layers.
        dropout_rate (float): Dropout rate used in both sublayers.
    """

    def __init__(self, dim, depth, n_heads, head_dim, mlp_dim, dropout_rate):
        super().__init__()
        self.layers = nn.LayerList([
            nn.LayerList([
                Residual(
                    PreNorm(dim,
                            SelfAttention(dim, n_heads, head_dim,
                                          dropout_rate))),
                Residual(
                    PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))),
            ]) for _ in range(depth)
        ])

    def forward(self, x):
        """Pass ``x`` through every attention / feed-forward pair in order."""
        for attention, feed_forward in self.layers:
            x = feed_forward(attention(x))
        return x
class TransformerDecoder(nn.Layer):
    """Stack of ``depth`` pre-norm residual (cross-attention, feed-forward) pairs.

    Args:
        dim (int): Feature size of the sequences.
        depth (int): Number of (attention, feed-forward) pairs.
        n_heads (int): Number of attention heads.
        head_dim (int): Feature size of each attention head.
        mlp_dim (int): Hidden size of the feed-forward layers.
        dropout_rate (float): Dropout rate used in both sublayers.
        apply_softmax (bool, optional): Forwarded to ``CrossAttention``.
            Default: True.
    """

    def __init__(self,
                 dim,
                 depth,
                 n_heads,
                 head_dim,
                 mlp_dim,
                 dropout_rate,
                 apply_softmax=True):
        super().__init__()
        self.layers = nn.LayerList([
            nn.LayerList([
                Residual2(
                    PreNorm2(dim,
                             CrossAttention(dim, n_heads, head_dim,
                                            dropout_rate, apply_softmax))),
                Residual(
                    PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))),
            ]) for _ in range(depth)
        ])

    def forward(self, x, m):
        """Refine ``x`` by attending to the memory ``m`` in every pair."""
        for attention, feed_forward in self.layers:
            x = feed_forward(attention(x, m))
        return x
class Backbone(nn.Layer, KaimingInitMixin):
    """Truncated ResNet feature extractor used by ``BIT``.

    The ResNet is built with strides ``(2, 1, 2, 1, 1)``. Its output is
    upsampled by a factor of 2 and projected to ``out_ch`` channels by a 3x3
    convolution.

    Args:
        in_ch (int): Number of input channels. If it is not 3, the first
            convolution of the ResNet is replaced by a new one.
        out_ch (int, optional): Number of output channels. Default: 32.
        arch (str, optional): 'resnet18' or 'resnet34'. Default: 'resnet18'.
        pretrained (bool, optional): Whether to load pretrained ResNet
            weights. If False, ``init_weight()`` is called instead.
            Default: True.
        n_stages (int, optional): Number of ResNet stages to keep; one of
            3, 4 or 5. Default: 5.

    Raises:
        ValueError: If ``arch`` is not supported or ``n_stages`` is not 3, 4
            or 5.
    """

    def __init__(self,
                 in_ch,
                 out_ch=32,
                 arch='resnet18',
                 pretrained=True,
                 n_stages=5):
        super().__init__()

        expand = 1
        strides = (2, 1, 2, 1, 1)
        if arch == 'resnet18':
            self.resnet = resnet.resnet18(
                pretrained=pretrained,
                strides=strides,
                norm_layer=get_norm_layer())
        elif arch == 'resnet34':
            self.resnet = resnet.resnet34(
                pretrained=pretrained,
                strides=strides,
                norm_layer=get_norm_layer())
        else:
            raise ValueError(
                "Unsupported backbone architecture {!r}; expected 'resnet18' "
                "or 'resnet34'.".format(arch))

        self.n_stages = n_stages

        # Channel count produced by the last stage that is kept.
        if self.n_stages == 5:
            itm_ch = 512 * expand
        elif self.n_stages == 4:
            itm_ch = 256 * expand
        elif self.n_stages == 3:
            itm_ch = 128 * expand
        else:
            raise ValueError(
                "n_stages must be 3, 4 or 5, but got {!r}.".format(n_stages))

        self.upsample = nn.Upsample(scale_factor=2)
        self.conv_out = Conv3x3(itm_ch, out_ch)

        self._trim_resnet()

        if in_ch != 3:
            # Replace the stem convolution so that it accepts in_ch bands.
            # The new layer keeps the stem stride and is not covered by any
            # pretrained weights.
            self.resnet.conv1 = nn.Conv2D(
                in_ch,
                64,
                kernel_size=7,
                stride=strides[0],
                padding=3,
                bias_attr=False)

        if not pretrained:
            self.init_weight()

    def forward(self, x):
        """Extract features with the trimmed ResNet, then upsample and project them."""
        y = self.resnet.conv1(x)
        y = self.resnet.bn1(y)
        y = self.resnet.relu(y)
        y = self.resnet.maxpool(y)

        y = self.resnet.layer1(y)
        y = self.resnet.layer2(y)
        # layer3 / layer4 are identity mappings when they have been trimmed.
        y = self.resnet.layer3(y)
        y = self.resnet.layer4(y)
        y = self.upsample(y)

        return self.conv_out(y)

    def _trim_resnet(self):
        """Replace the unused ResNet stages and the classification head with ``Identity``."""
        if self.n_stages > 5:
            raise ValueError(
                "n_stages must not be greater than 5, but got {!r}.".format(
                    self.n_stages))

        if self.n_stages < 5:
            self.resnet.layer4 = Identity()

        if self.n_stages <= 3:
            self.resnet.layer3 = Identity()

        # The pooling / fc head is never called by ``forward``.
        self.resnet.avgpool = Identity()
        self.resnet.fc = Identity()

@ -0,0 +1,96 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .layers import make_norm, Conv3x3, CBAM
from .stanet import Backbone, Decoder
class DSAMNet(nn.Layer):
    """
    The DSAMNet implementation based on PaddlePaddle.

    The original article refers to
        Q. Shi, et al., "A Deeply Supervised Attention Metric-Based Network and an Open Aerial Image Dataset for Remote
        Sensing Change Detection"
        (https://ieeexplore.ieee.org/document/9467555).

    Note that this implementation differs from the original work in two aspects:
    1. We do not use multiple dilation rates in layer 4 of the ResNet backbone.
    2. A classification head is used in place of the original metric learning-based head to stabilize the training
        process.

    Args:
        in_channels (int): The number of bands of the input images.
        num_classes (int): The number of target classes.
        ca_ratio (int, optional): The channel reduction ratio for the channel attention module. Default: 8.
        sa_kernel (int, optional): The size of the convolutional kernel used in the spatial attention module.
            Default: 7.
    """

    def __init__(self, in_channels, num_classes, ca_ratio=8, sa_kernel=7):
        super().__init__()

        WIDTH = 64

        self.backbone = Backbone(
            in_ch=in_channels, arch='resnet18', strides=(1, 1, 2, 2, 1))
        self.decoder = Decoder(WIDTH)

        # One attention module per temporal phase.
        self.cbam1 = CBAM(64, ratio=ca_ratio, kernel_size=sa_kernel)
        self.cbam2 = CBAM(64, ratio=ca_ratio, kernel_size=sa_kernel)

        # Deep supervision heads fed with the first two backbone outputs.
        self.dsl2 = DSLayer(64, num_classes, 32, stride=2, output_padding=1)
        self.dsl3 = DSLayer(128, num_classes, 32, stride=4, output_padding=3)

        self.conv_out = nn.Sequential(
            Conv3x3(
                WIDTH, WIDTH, norm=True, act=True),
            Conv3x3(WIDTH, num_classes))

        self.init_weight()

    def forward(self, t1, t2):
        """Predict change logits from a bi-temporal image pair.

        Returns:
            tuple: ``(pred, ds2, ds3)``: the main prediction followed by the
                two deep supervision outputs.
        """
        feats1 = self.backbone(t1)
        feats2 = self.backbone(t2)

        dec1 = self.decoder(feats1)
        dec2 = self.decoder(feats2)

        att1 = self.cbam1(dec1)
        att2 = self.cbam2(dec2)

        # Resize the feature difference to the spatial size of the input.
        diff = F.interpolate(
            paddle.abs(att1 - att2),
            size=paddle.shape(t1)[2:],
            mode='bilinear',
            align_corners=True)
        pred = self.conv_out(diff)

        ds2 = self.dsl2(paddle.abs(feats1[0] - feats2[0]))
        ds3 = self.dsl3(paddle.abs(feats1[1] - feats2[1]))

        return pred, ds2, ds3

    def init_weight(self):
        """Keep the default initialization of every sublayer."""
        pass
class DSLayer(nn.Sequential):
    """Deep supervision head: ConvTranspose -> norm -> ReLU -> Dropout2D -> ConvTranspose.

    Args:
        in_ch (int): Number of input channels.
        out_ch (int): Number of output channels.
        itm_ch (int): Number of intermediate channels.
        **convd_kwargs: Extra keyword arguments forwarded to the first
            transposed convolution only (e.g. ``stride``, ``output_padding``).
    """

    def __init__(self, in_ch, out_ch, itm_ch, **convd_kwargs):
        upsampler = nn.Conv2DTranspose(
            in_ch, itm_ch, kernel_size=3, padding=1, **convd_kwargs)
        norm = make_norm(itm_ch)
        activation = nn.ReLU()
        dropout = nn.Dropout2D(p=0.2)
        classifier = nn.Conv2DTranspose(
            itm_ch, out_ch, kernel_size=3, padding=1)
        super().__init__(upsampler, norm, activation, dropout, classifier)

@ -0,0 +1,209 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.vision.models import vgg16
from .layers import Conv1x1, make_norm, ChannelAttention, SpatialAttention
class DSIFN(nn.Layer):
    """
    The DSIFN implementation based on PaddlePaddle.

    The original article refers to
        C. Zhang, et al., "A deeply supervised image fusion network for change detection in high resolution bi-temporal
        remote sensing images"
        (https://www.sciencedirect.com/science/article/pii/S0924271620301532).

    Note that in this implementation, there is a flexible number of target classes.

    Args:
        num_classes (int): The number of target classes.
        use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is
            trained on a relatively small dataset, the dropout layers help prevent overfitting. Default: False.
    """

    def __init__(self, num_classes, use_dropout=False):
        super().__init__()

        # Both names refer to one and the same feature extractor instance.
        self.encoder1 = self.encoder2 = VGG16FeaturePicker()

        self.sa1 = SpatialAttention()
        self.sa2 = SpatialAttention()
        self.sa3 = SpatialAttention()
        self.sa4 = SpatialAttention()
        self.sa5 = SpatialAttention()

        # Decoding level 1. ``ca1`` and ``bn_ca1`` are created but not used
        # in ``forward``.
        self.ca1 = ChannelAttention(in_ch=1024)
        self.bn_ca1 = make_norm(1024)
        self.o1_conv1 = conv2d_bn(1024, 512, use_dropout)
        self.o1_conv2 = conv2d_bn(512, 512, use_dropout)
        self.bn_sa1 = make_norm(512)
        self.o1_conv3 = Conv1x1(512, num_classes)
        self.trans_conv1 = nn.Conv2DTranspose(
            512, 512, kernel_size=2, stride=2)

        # Decoding level 2. ``bn_ca2`` is created but not used in ``forward``.
        self.ca2 = ChannelAttention(in_ch=1536)
        self.bn_ca2 = make_norm(1536)
        self.o2_conv1 = conv2d_bn(1536, 512, use_dropout)
        self.o2_conv2 = conv2d_bn(512, 256, use_dropout)
        self.o2_conv3 = conv2d_bn(256, 256, use_dropout)
        self.bn_sa2 = make_norm(256)
        self.o2_conv4 = Conv1x1(256, num_classes)
        self.trans_conv2 = nn.Conv2DTranspose(
            256, 256, kernel_size=2, stride=2)

        # Decoding level 3.
        self.ca3 = ChannelAttention(in_ch=768)
        self.o3_conv1 = conv2d_bn(768, 256, use_dropout)
        self.o3_conv2 = conv2d_bn(256, 128, use_dropout)
        self.o3_conv3 = conv2d_bn(128, 128, use_dropout)
        self.bn_sa3 = make_norm(128)
        self.o3_conv4 = Conv1x1(128, num_classes)
        self.trans_conv3 = nn.Conv2DTranspose(
            128, 128, kernel_size=2, stride=2)

        # Decoding level 4.
        self.ca4 = ChannelAttention(in_ch=384)
        self.o4_conv1 = conv2d_bn(384, 128, use_dropout)
        self.o4_conv2 = conv2d_bn(128, 64, use_dropout)
        self.o4_conv3 = conv2d_bn(64, 64, use_dropout)
        self.bn_sa4 = make_norm(64)
        self.o4_conv4 = Conv1x1(64, num_classes)
        self.trans_conv4 = nn.Conv2DTranspose(64, 64, kernel_size=2, stride=2)

        # Decoding level 5.
        self.ca5 = ChannelAttention(in_ch=192)
        self.o5_conv1 = conv2d_bn(192, 64, use_dropout)
        self.o5_conv2 = conv2d_bn(64, 32, use_dropout)
        self.o5_conv3 = conv2d_bn(32, 16, use_dropout)
        self.bn_sa5 = make_norm(16)
        self.o5_conv4 = Conv1x1(16, num_classes)

        self.init_weight()

    def forward(self, t1, t2):
        """Predict change logits at five decoding levels.

        Returns:
            tuple: ``(out5, out4, out3, out2, out1)``, i.e. from the last
                decoding level back to the first. ``out1`` to ``out4`` are
                bilinearly resized to the spatial size of ``t1``.
        """
        # Extract bi-temporal features.
        with paddle.no_grad():
            self.encoder1.eval()
            self.encoder2.eval()
            t1_feats = self.encoder1(t1)
            t2_feats = self.encoder2(t2)

        t1_l3, t1_l8, t1_l15, t1_l22, t1_l29 = t1_feats
        t2_l3, t2_l8, t2_l15, t2_l22, t2_l29 = t2_feats

        out_size = paddle.shape(t1)[2:]

        def resize_to_input(logits):
            return F.interpolate(
                logits, size=out_size, mode='bilinear', align_corners=True)

        # Multi-level decoding
        # Level 1
        x = paddle.concat([t1_l29, t2_l29], axis=1)
        x = self.o1_conv1(x)
        x = self.o1_conv2(x)
        x = self.sa1(x) * x
        x = self.bn_sa1(x)
        out1 = resize_to_input(self.o1_conv3(x))

        # Level 2
        x = self.trans_conv1(x)
        x = paddle.concat([x, t1_l22, t2_l22], axis=1)
        x = self.ca2(x) * x
        x = self.o2_conv1(x)
        x = self.o2_conv2(x)
        x = self.o2_conv3(x)
        x = self.sa2(x) * x
        x = self.bn_sa2(x)
        out2 = resize_to_input(self.o2_conv4(x))

        # Level 3
        x = self.trans_conv2(x)
        x = paddle.concat([x, t1_l15, t2_l15], axis=1)
        x = self.ca3(x) * x
        x = self.o3_conv1(x)
        x = self.o3_conv2(x)
        x = self.o3_conv3(x)
        x = self.sa3(x) * x
        x = self.bn_sa3(x)
        out3 = resize_to_input(self.o3_conv4(x))

        # Level 4
        x = self.trans_conv3(x)
        x = paddle.concat([x, t1_l8, t2_l8], axis=1)
        x = self.ca4(x) * x
        x = self.o4_conv1(x)
        x = self.o4_conv2(x)
        x = self.o4_conv3(x)
        x = self.sa4(x) * x
        x = self.bn_sa4(x)
        out4 = resize_to_input(self.o4_conv4(x))

        # Level 5 (its output is returned without resizing)
        x = self.trans_conv4(x)
        x = paddle.concat([x, t1_l3, t2_l3], axis=1)
        x = self.ca5(x) * x
        x = self.o5_conv1(x)
        x = self.o5_conv2(x)
        x = self.o5_conv3(x)
        x = self.sa5(x) * x
        x = self.bn_sa5(x)
        out5 = self.o5_conv4(x)

        return out5, out4, out3, out2, out1

    def init_weight(self):
        """Keep the default initialization of every sublayer."""
        # Do nothing
        pass
class VGG16FeaturePicker(nn.Layer):
    """Collect intermediate activations of a pretrained VGG16.

    Only the first 30 layers of ``vgg16(pretrained=True).features`` are kept,
    and they are switched to evaluation mode on construction.

    Args:
        indices (tuple[int], optional): Indices of the layers whose outputs
            are returned. Default: (3, 8, 15, 22, 29).
    """

    def __init__(self, indices=(3, 8, 15, 22, 29)):
        super().__init__()
        vgg_layers = list(vgg16(pretrained=True).features)
        self.features = nn.LayerList(vgg_layers[:30])
        self.features.eval()
        self.indices = set(indices)

    def forward(self, x):
        """Return the outputs of the selected layers, in layer order."""
        selected = []
        for layer_idx, layer in enumerate(self.features):
            x = layer(x)
            if layer_idx not in self.indices:
                continue
            selected.append(x)
        return selected
def conv2d_bn(in_ch, out_ch, with_dropout=True):
    """Build a 3x3 Conv -> PReLU -> norm block, optionally followed by dropout.

    Args:
        in_ch (int): Number of input channels.
        out_ch (int): Number of output channels.
        with_dropout (bool, optional): Whether to append a dropout layer
            with ``p=0.6``. Default: True.

    Returns:
        nn.Sequential: The assembled block.
    """
    conv = nn.Conv2D(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
    activation = nn.PReLU()
    norm = make_norm(out_ch)
    tail = [nn.Dropout(p=0.6)] if with_dropout else []
    return nn.Sequential(conv, activation, norm, *tail)

@ -0,0 +1,16 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .blocks import *
from .attention import ChannelAttention, SpatialAttention, CBAM

@ -0,0 +1,96 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .blocks import Conv1x1, BasicConv
class ChannelAttention(nn.Layer):
    """
    The channel attention module implementation based on PaddlePaddle.

    The original article refers to
        Sanghyun Woo, et al., "CBAM: Convolutional Block Attention Module"
        (https://arxiv.org/abs/1807.06521).

    Globally average-pooled and max-pooled features go through the same two
    1x1 convolutions; the two results are summed and squashed by a sigmoid.

    Args:
        in_ch (int): The number of channels of the input features.
        ratio (int, optional): The channel reduction ratio. Default: 8.
    """

    def __init__(self, in_ch, ratio=8):
        super().__init__()
        reduced_ch = in_ch // ratio
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        self.max_pool = nn.AdaptiveMaxPool2D(1)
        self.fc1 = Conv1x1(in_ch, reduced_ch, bias=False, act=True)
        self.fc2 = Conv1x1(reduced_ch, in_ch, bias=False)

    def forward(self, x):
        """Return the per-channel attention weights for ``x``."""

        def shared_mlp(pooled):
            return self.fc2(self.fc1(pooled))

        avg_branch = shared_mlp(self.avg_pool(x))
        max_branch = shared_mlp(self.max_pool(x))
        return F.sigmoid(avg_branch + max_branch)
class SpatialAttention(nn.Layer):
    """
    Spatial attention module of CBAM, implemented with PaddlePaddle.

    Reference:
        Sanghyun Woo, et al., "CBAM: Convolutional Block Attention Module"
        (https://arxiv.org/abs/1807.06521).

    Args:
        kernel_size (int, optional): The size of the convolutional kernel. Default: 7.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        # Maps the 2-channel (mean, max) descriptor to a single-channel map.
        self.conv = BasicConv(2, 1, kernel_size, bias=False)

    def forward(self, x):
        # Mean and max are both reduced along axis 1, keeping that axis with size 1.
        descriptors = [
            paddle.mean(x, axis=1, keepdim=True),
            paddle.max(x, axis=1, keepdim=True),
        ]
        stacked = paddle.concat(descriptors, axis=1)
        return F.sigmoid(self.conv(stacked))
class CBAM(nn.Layer):
    """
    Convolutional Block Attention Module (CBAM), implemented with PaddlePaddle.

    Reference:
        Sanghyun Woo, et al., "CBAM: Convolutional Block Attention Module"
        (https://arxiv.org/abs/1807.06521).

    Args:
        in_ch (int): The number of channels of the input features.
        ratio (int, optional): The channel reduction ratio for the channel attention module. Default: 8.
        kernel_size (int, optional): The size of the convolutional kernel used in the spatial attention
            module. Default: 7.
    """

    def __init__(self, in_ch, ratio=8, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(in_ch, ratio=ratio)
        self.sa = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        # Channel attention is applied first; spatial attention is computed on its result.
        channel_refined = self.ca(x) * x
        return self.sa(channel_refined) * channel_refined

@ -0,0 +1,142 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.nn as nn
# Names exported by a star-import of this module (the package __init__ does `from .blocks import *`).
__all__ = [
'BasicConv', 'Conv1x1', 'Conv3x3', 'Conv7x7',
'MaxPool2x2', 'MaxUnPool2x2',
'ConvTransposed3x3',
'Identity',
'get_norm_layer', 'get_act_layer',
'make_norm', 'make_act'
]
def get_norm_layer():
    """Return the normalization layer class used by the building blocks (currently `nn.BatchNorm2D`)."""
    # TODO: select appropriate norm layer.
    norm_cls = nn.BatchNorm2D
    return norm_cls
def get_act_layer():
    """Return the activation layer class used by the building blocks (currently `nn.ReLU`)."""
    # TODO: select appropriate activation layer.
    act_cls = nn.ReLU
    return act_cls
def make_norm(*args, **kwargs):
    """Instantiate the class returned by `get_norm_layer()`, forwarding all arguments to it."""
    return get_norm_layer()(*args, **kwargs)
def make_act(*args, **kwargs):
    """Instantiate the class returned by `get_act_layer()`, forwarding all arguments to it."""
    return get_act_layer()(*args, **kwargs)
class BasicConv(nn.Layer):
    """
    A stride-1 convolution with optional explicit padding, normalization, and activation.

    The sublayers are stored in `self.seq` in this order: padding (only when
    `kernel_size >= 2`), convolution, normalization (optional), activation (optional).

    Args:
        in_ch (int): The number of input channels of the convolution.
        out_ch (int): The number of output channels of the convolution.
        kernel_size (int): The kernel size. A `nn.Pad2D` layer with padding `kernel_size//2`
            is inserted before the convolution when `kernel_size >= 2`.
        pad_mode (str, optional): The `mode` passed to `nn.Pad2D`. Default: 'constant'.
        bias (optional): The `bias_attr` of the convolution. With the special value 'auto', the
            bias is disabled (`False`) when `norm` is truthy and left as `None` otherwise.
            Default: 'auto'.
        norm (bool|nn.Layer, optional): `True` appends `make_norm(out_ch)`; any other truthy
            value is appended as is. Default: False.
        act (bool|nn.Layer, optional): `True` appends `make_act()`; any other truthy value is
            appended as is. Default: False.
        **kwargs: Extra keyword arguments forwarded to `nn.Conv2D`.
    """

    def __init__(
            self, in_ch, out_ch,
            kernel_size, pad_mode='constant',
            bias='auto', norm=False, act=False,
            **kwargs
    ):
        super().__init__()

        # Resolve the 'auto' bias policy up front.
        if bias == 'auto':
            bias_attr = False if norm else None
        else:
            bias_attr = bias

        layers = []
        if kernel_size >= 2:
            # Padding is done by a dedicated layer so that `pad_mode` can be chosen;
            # the convolution itself uses padding=0.
            layers.append(nn.Pad2D(kernel_size // 2, mode=pad_mode))
        layers.append(
            nn.Conv2D(
                in_ch, out_ch, kernel_size,
                stride=1, padding=0,
                bias_attr=bias_attr,
                **kwargs
            )
        )
        if norm:
            layers.append(make_norm(out_ch) if norm is True else norm)
        if act:
            layers.append(make_act() if act is True else act)
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        return self.seq(x)
class Conv1x1(BasicConv):
    """A `BasicConv` with the kernel size fixed to 1; all other arguments are forwarded unchanged."""

    def __init__(self, in_ch, out_ch, pad_mode='constant', bias='auto', norm=False, act=False, **kwargs):
        super().__init__(
            in_ch, out_ch, 1,
            pad_mode=pad_mode,
            bias=bias,
            norm=norm,
            act=act,
            **kwargs
        )
class Conv3x3(BasicConv):
    """A `BasicConv` with the kernel size fixed to 3; all other arguments are forwarded unchanged."""

    def __init__(self, in_ch, out_ch, pad_mode='constant', bias='auto', norm=False, act=False, **kwargs):
        super().__init__(
            in_ch, out_ch, 3,
            pad_mode=pad_mode,
            bias=bias,
            norm=norm,
            act=act,
            **kwargs
        )
class Conv7x7(BasicConv):
    """A `BasicConv` with the kernel size fixed to 7; all other arguments are forwarded unchanged."""

    def __init__(self, in_ch, out_ch, pad_mode='constant', bias='auto', norm=False, act=False, **kwargs):
        super().__init__(
            in_ch, out_ch, 7,
            pad_mode=pad_mode,
            bias=bias,
            norm=norm,
            act=act,
            **kwargs
        )
class MaxPool2x2(nn.MaxPool2D):
    """A `nn.MaxPool2D` with kernel size 2, stride (2, 2), and padding (0, 0).

    Args:
        **kwargs: Extra keyword arguments forwarded to `nn.MaxPool2D`.
    """

    def __init__(self, **kwargs):
        super().__init__(
            kernel_size=2,
            stride=(2, 2),
            padding=(0, 0),
            **kwargs
        )
class MaxUnPool2x2(nn.MaxUnPool2D):
    """A `nn.MaxUnPool2D` with kernel size 2, stride (2, 2), and padding (0, 0).

    Args:
        **kwargs: Extra keyword arguments forwarded to `nn.MaxUnPool2D`.
    """

    def __init__(self, **kwargs):
        super().__init__(
            kernel_size=2,
            stride=(2, 2),
            padding=(0, 0),
            **kwargs
        )
class ConvTransposed3x3(nn.Layer):
    """
    A 3x3 transposed convolution (stride 2, padding 1) with optional normalization and activation.

    The sublayers are stored in `self.seq` in this order: transposed convolution,
    normalization (optional), activation (optional).

    Args:
        in_ch (int): The number of input channels.
        out_ch (int): The number of output channels.
        bias (optional): The `bias_attr` of the transposed convolution. With the special value
            'auto', the bias is disabled (`False`) when `norm` is truthy and left as `None`
            otherwise. Default: 'auto'.
        norm (bool|nn.Layer, optional): `True` appends `make_norm(out_ch)`; any other truthy
            value is appended as is. Default: False.
        act (bool|nn.Layer, optional): `True` appends `make_act()`; any other truthy value is
            appended as is. Default: False.
        **kwargs: Extra keyword arguments forwarded to `nn.Conv2DTranspose`
            (e.g. `output_padding`).
    """

    def __init__(
            self, in_ch, out_ch,
            bias='auto', norm=False, act=False,
            **kwargs
    ):
        super().__init__()

        # Resolve the 'auto' bias policy up front.
        if bias == 'auto':
            bias_attr = False if norm else None
        else:
            bias_attr = bias

        layers = [
            nn.Conv2DTranspose(
                in_ch, out_ch, 3,
                stride=2, padding=1,
                bias_attr=bias_attr,
                **kwargs
            )
        ]
        if norm:
            layers.append(make_norm(out_ch) if norm is True else norm)
        if act:
            layers.append(make_act() if act is True else act)
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        return self.seq(x)
class Identity(nn.Layer):
    """
    A placeholder layer that returns its single input unchanged.

    Constructor arguments are accepted and ignored, so the class can stand in for
    layers that take arbitrary construction parameters.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x):
        return x

@ -0,0 +1,86 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.nn as nn
import paddle.nn.functional as F
def normal_init(param, *args, **kwargs):
    """
    Initialize a parameter with a normal distribution.

    Args:
        param (Tensor): The tensor that needs to be initialized.
        *args, **kwargs: Forwarded to `nn.initializer.Normal`.

    Returns:
        The result of applying the initializer to `param`.
    """
    initializer = nn.initializer.Normal(*args, **kwargs)
    return initializer(param)
def kaiming_normal_init(param, *args, **kwargs):
    """
    Initialize a parameter with the Kaiming normal distribution.

    For more information about the Kaiming initialization method, please refer to
    https://arxiv.org/abs/1502.01852

    Args:
        param (Tensor): The tensor that needs to be initialized.
        *args, **kwargs: Forwarded to `nn.initializer.KaimingNormal`.

    Returns:
        The result of applying the initializer to `param`.
    """
    initializer = nn.initializer.KaimingNormal(*args, **kwargs)
    return initializer(param)
def constant_init(param, *args, **kwargs):
    """
    Initialize a parameter with a constant value.

    Args:
        param (Tensor): The tensor that needs to be initialized.
        *args, **kwargs: Forwarded to `nn.initializer.Constant` (e.g. `value=1`).

    Returns:
        The result of applying the initializer to `param`.
    """
    initializer = nn.initializer.Constant(*args, **kwargs)
    return initializer(param)
class KaimingInitMixin:
    """
    A mix-in that provides the Kaiming initialization functionality.

    The host class is expected to provide `sublayers()` (as `nn.Layer` does).

    Examples:

        from paddlers.models.cd.models.param_init import KaimingInitMixin

        class CustomNet(nn.Layer, KaimingInitMixin):
            def __init__(self, num_channels, num_classes):
                super().__init__()
                self.conv = nn.Conv2D(num_channels, num_classes, 3, 1, 0, bias_attr=False)
                self.bn = nn.BatchNorm2D(num_classes)
                self.init_weight()
    """

    def init_weight(self):
        """Re-initialize convolution weights (Kaiming normal) and batch-norm affine parameters (1 / 0)."""
        for sublayer in self.sublayers():
            if isinstance(sublayer, nn.Conv2D):
                kaiming_normal_init(sublayer.weight)
                continue
            # NOTE(review): the check names nn.BatchNorm / nn.SyncBatchNorm only; whether layers
            # created as nn.BatchNorm2D match it depends on Paddle's class hierarchy -- confirm.
            if isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)):
                constant_init(sublayer.weight, value=1)
                constant_init(sublayer.bias, value=0)

@ -0,0 +1,155 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .layers import Conv1x1, MaxPool2x2, make_norm, ChannelAttention
from .param_init import KaimingInitMixin
class SNUNet(nn.Layer, KaimingInitMixin):
    """
    The SNUNet implementation based on PaddlePaddle.

    The original article refers to
        S. Fang, et al., "SNUNet-CD: A Densely Connected Siamese Network for Change Detection of VHR Images"
        (https://ieeexplore.ieee.org/document/9355573).

    Note that bilinear interpolation is adopted as the upsampling method, which is different from the paper.

    Args:
        in_channels (int): The number of bands of the input images.
        num_classes (int): The number of target classes.
        width (int, optional): The output channels of the first convolutional layer. Default: 32.
    """

    def __init__(self, in_channels, num_classes, width=32):
        super().__init__()

        # Channel widths of the five encoder levels.
        c0, c1, c2, c3, c4 = width, width * 2, width * 4, width * 8, width * 16

        # Encoder column (shared by both temporal phases).
        self.conv0_0 = ConvBlockNested(in_channels, c0, c0)
        self.conv1_0 = ConvBlockNested(c0, c1, c1)
        self.conv2_0 = ConvBlockNested(c1, c2, c2)
        self.conv3_0 = ConvBlockNested(c2, c3, c3)
        self.conv4_0 = ConvBlockNested(c3, c4, c4)
        self.down1 = MaxPool2x2()
        self.down2 = MaxPool2x2()
        self.down3 = MaxPool2x2()
        self.down4 = MaxPool2x2()
        self.up1_0 = Up(c1)
        self.up2_0 = Up(c2)
        self.up3_0 = Up(c3)
        self.up4_0 = Up(c4)

        # First nested column: two encoder features + one upsampled feature.
        self.conv0_1 = ConvBlockNested(c0 * 2 + c1, c0, c0)
        self.conv1_1 = ConvBlockNested(c1 * 2 + c2, c1, c1)
        self.conv2_1 = ConvBlockNested(c2 * 2 + c3, c2, c2)
        self.conv3_1 = ConvBlockNested(c3 * 2 + c4, c3, c3)
        self.up1_1 = Up(c1)
        self.up2_1 = Up(c2)
        self.up3_1 = Up(c3)

        # Second nested column.
        self.conv0_2 = ConvBlockNested(c0 * 3 + c1, c0, c0)
        self.conv1_2 = ConvBlockNested(c1 * 3 + c2, c1, c1)
        self.conv2_2 = ConvBlockNested(c2 * 3 + c3, c2, c2)
        self.up1_2 = Up(c1)
        self.up2_2 = Up(c2)

        # Third nested column.
        self.conv0_3 = ConvBlockNested(c0 * 4 + c1, c0, c0)
        self.conv1_3 = ConvBlockNested(c1 * 4 + c2, c1, c1)
        self.up1_3 = Up(c1)

        # Fourth nested column.
        self.conv0_4 = ConvBlockNested(c0 * 5 + c1, c0, c0)

        # Channel attention over the summed outputs (intra) and the concatenated outputs (inter).
        self.ca_intra = ChannelAttention(c0, ratio=4)
        self.ca_inter = ChannelAttention(c0 * 4, ratio=16)

        self.conv_out = Conv1x1(c0 * 4, num_classes)

        self.init_weight()

    def forward(self, t1, t2):
        def fuse(same_level, upsampled):
            # Concatenate all features of one level with the upsampled deeper feature (axis 1).
            return paddle.concat(same_level + [upsampled], 1)

        # Encode the first phase (levels 0-3 only).
        a0 = self.conv0_0(t1)
        a1 = self.conv1_0(self.down1(a0))
        a2 = self.conv2_0(self.down2(a1))
        a3 = self.conv3_0(self.down3(a2))

        # Encode the second phase with the same layers (levels 0-4).
        b0 = self.conv0_0(t2)
        b1 = self.conv1_0(self.down1(b0))
        b2 = self.conv2_0(self.down2(b1))
        b3 = self.conv3_0(self.down3(b2))
        b4 = self.conv4_0(self.down4(b3))

        # Per-level feature lists; nested outputs are appended as they become available.
        level0 = [a0, b0]
        level1 = [a1, b1]
        level2 = [a2, b2]
        level3 = [a3, b3]

        x0_1 = self.conv0_1(fuse(level0, self.up1_0(b1)))
        level0.append(x0_1)
        x1_1 = self.conv1_1(fuse(level1, self.up2_0(b2)))
        level1.append(x1_1)
        x0_2 = self.conv0_2(fuse(level0, self.up1_1(x1_1)))
        level0.append(x0_2)

        x2_1 = self.conv2_1(fuse(level2, self.up3_0(b3)))
        level2.append(x2_1)
        x1_2 = self.conv1_2(fuse(level1, self.up2_1(x2_1)))
        level1.append(x1_2)
        x0_3 = self.conv0_3(fuse(level0, self.up1_2(x1_2)))
        level0.append(x0_3)

        x3_1 = self.conv3_1(fuse(level3, self.up4_0(b4)))
        x2_2 = self.conv2_2(fuse(level2, self.up3_1(x3_1)))
        x1_3 = self.conv1_3(fuse(level1, self.up2_2(x2_2)))
        x0_4 = self.conv0_4(fuse(level0, self.up1_3(x1_3)))

        # Ensemble the four top-level outputs with channel attention.
        heads = [x0_1, x0_2, x0_3, x0_4]
        out = paddle.concat(heads, 1)
        intra = paddle.sum(paddle.stack(heads), axis=0)
        m_intra = self.ca_intra(intra)
        # The intra attention map is repeated 4 times along axis 1 to match the concatenation.
        out = self.ca_inter(out) * (out + paddle.tile(m_intra, (1, 4, 1, 1)))

        pred = self.conv_out(out)
        # The prediction is returned wrapped in a 1-tuple.
        return (pred,)
class ConvBlockNested(nn.Layer):
    """
    Two 3x3 conv + norm stages with a shortcut taken right after the first convolution.

    The shortcut is added to the output of the second normalization, so `mid_ch`
    has to be compatible with `out_ch` for the addition.

    Args:
        in_ch (int): The number of input channels.
        out_ch (int): The number of output channels.
        mid_ch (int): The number of output channels of the first convolution.
    """

    def __init__(self, in_ch, out_ch, mid_ch):
        super().__init__()
        self.act = nn.ReLU()
        self.conv1 = nn.Conv2D(in_ch, mid_ch, kernel_size=3, padding=1)
        self.bn1 = make_norm(mid_ch)
        self.conv2 = nn.Conv2D(mid_ch, out_ch, kernel_size=3, padding=1)
        self.bn2 = make_norm(out_ch)

    def forward(self, x):
        # The raw output of conv1 (before norm/activation) serves as the shortcut.
        shortcut = self.conv1(x)
        y = self.act(self.bn1(shortcut))
        y = self.bn2(self.conv2(y))
        return self.act(y + shortcut)
class Up(nn.Layer):
    """
    Upsampling layer with a factor/stride of 2.

    Args:
        in_ch (int): The number of channels; only used when `use_conv` is True.
        use_conv (bool, optional): If True, use a `nn.Conv2DTranspose` (kernel 2, stride 2);
            otherwise use bilinear `nn.Upsample` with `align_corners=True`. Default: False.
    """

    def __init__(self, in_ch, use_conv=False):
        super().__init__()
        self.up = (
            nn.Conv2DTranspose(in_ch, in_ch, 2, stride=2)
            if use_conv
            else nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )

    def forward(self, x):
        return self.up(x)

@ -0,0 +1,298 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .backbones import resnet
from .layers import Conv1x1, Conv3x3, get_norm_layer, Identity
from .param_init import KaimingInitMixin
class STANet(nn.Layer):
    """
    The STANet implementation based on PaddlePaddle.

    The original article refers to
        H. Chen and Z. Shi, "A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing
        Image Change Detection"
        (https://www.mdpi.com/2072-4292/12/10/1662).

    Note that this implementation differs from the original work in two aspects:
    1. We do not use multiple dilation rates in layer 4 of the ResNet backbone.
    2. A classification head is used in place of the original metric learning-based head to stabilize the
       training process.

    Args:
        in_channels (int): The number of bands of the input images.
        num_classes (int): The number of target classes.
        att_type (str, optional): The attention module used in the model. Options are 'PAM' and 'BAM'.
            Default: 'BAM'.
        ds_factor (int, optional): The downsampling factor of the attention modules. When `ds_factor` is set
            to values greater than 1, the input features will first be processed by an average pooling layer
            with the kernel size of `ds_factor`, before being used to calculate the attention scores.
            Default: 1.

    Raises:
        ValueError: When `att_type` has an illegal value (unsupported attention type).
    """

    def __init__(
            self,
            in_channels,
            num_classes,
            att_type='BAM',
            ds_factor=1
    ):
        super().__init__()

        # Channel count of the features shared by the extractor, the attention module and the head.
        feat_ch = 64

        self.extract = build_feat_extractor(in_ch=in_channels, width=feat_ch)
        self.attend = build_sta_module(in_ch=feat_ch, att_type=att_type, ds=ds_factor)
        self.conv_out = nn.Sequential(
            Conv3x3(feat_ch, feat_ch, norm=True, act=True),
            Conv3x3(feat_ch, num_classes)
        )

        self.init_weight()

    def forward(self, t1, t2):
        # Both phases go through the same feature extractor.
        feat1 = self.extract(t1)
        feat2 = self.extract(t2)
        feat1, feat2 = self.attend(feat1, feat2)

        # The absolute feature difference is resized to the spatial size of `t1` before classification.
        diff = paddle.abs(feat1 - feat2)
        diff = F.interpolate(
            diff,
            size=paddle.shape(t1)[2:],
            mode='bilinear',
            align_corners=True
        )

        pred = self.conv_out(diff)
        # The prediction is returned wrapped in a 1-tuple.
        return (pred,)

    def init_weight(self):
        # Do nothing here as the encoder and decoder weights have already been initialized.
        # Note however that currently self.attend and self.conv_out use the default initialization method.
        pass
def build_feat_extractor(in_ch, width):
    """
    Build the feature extractor: a ResNet-18 `Backbone` followed by a `Decoder`.

    Args:
        in_ch (int): The number of channels of the input images.
        width (int): The number of output channels of the decoder.

    Returns:
        nn.Sequential: The backbone and the decoder chained together.
    """
    backbone = Backbone(in_ch, 'resnet18')
    decoder = Decoder(width)
    return nn.Sequential(backbone, decoder)
def build_sta_module(in_ch, att_type, ds):
    """
    Build the spatial-temporal attention module.

    Args:
        in_ch (int): The number of channels of the input features.
        att_type (str): The attention type. Options are 'BAM' and 'PAM'.
        ds (int): The downsampling factor passed to the attention module.

    Returns:
        Attention: The selected attention module wrapped in an `Attention` layer.

    Raises:
        ValueError: When `att_type` is neither 'BAM' nor 'PAM'.
    """
    if att_type == 'BAM':
        att = BAM(in_ch, ds)
    elif att_type == 'PAM':
        att = PAM(in_ch, ds)
    else:
        # Report the offending value and the valid options instead of raising a bare error.
        raise ValueError(
            "Unsupported attention type: {!r}. Supported types are 'BAM' and 'PAM'.".format(att_type)
        )
    return Attention(att)
class Backbone(nn.Layer, KaimingInitMixin):
    """
    ResNet feature extractor that returns the outputs of the four residual stages.

    The average pooling and fully-connected layers of the ResNet are replaced with
    `Identity` placeholders.

    Args:
        in_ch (int): The number of channels of the input images. When it is not 3, the first
            convolution of the ResNet is replaced with a newly created 7x7 convolution.
        arch (str): The backbone architecture. Options are 'resnet18', 'resnet34', and 'resnet50'.
        pretrained (bool, optional): Forwarded to the ResNet builder. When False, `init_weight()`
            is called at the end of construction. Default: True.
        strides (tuple, optional): Forwarded to the ResNet builder; `strides[0]` is also used as the
            stride of the replacement first convolution. Default: (2,1,2,2,2).

    Raises:
        ValueError: When `arch` is not one of the supported architectures.
    """

    _SUPPORTED_ARCHS = ('resnet18', 'resnet34', 'resnet50')

    def __init__(self, in_ch, arch, pretrained=True, strides=(2,1,2,2,2)):
        super().__init__()

        if arch not in self._SUPPORTED_ARCHS:
            # Report the offending value and the valid options instead of raising a bare error.
            raise ValueError(
                "Unsupported backbone architecture: {!r}. Supported architectures are {}.".format(
                    arch, ', '.join(repr(name) for name in self._SUPPORTED_ARCHS)
                )
            )
        # The builder functions in the `resnet` module are named after the architecture.
        builder = getattr(resnet, arch)
        self.resnet = builder(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer())

        self._trim_resnet()

        if in_ch != 3:
            # The replacement convolution is freshly created, so it does not carry weights
            # loaded by the ResNet builder.
            self.resnet.conv1 = nn.Conv2D(
                in_ch,
                64,
                kernel_size=7,
                stride=strides[0],
                padding=3,
                bias_attr=False
            )

        if not pretrained:
            self.init_weight()

    def forward(self, x):
        # Stem.
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        # Residual stages; every stage output is returned.
        x1 = self.resnet.layer1(x)
        x2 = self.resnet.layer2(x1)
        x3 = self.resnet.layer3(x2)
        x4 = self.resnet.layer4(x3)

        return x1, x2, x3, x4

    def _trim_resnet(self):
        # Replace the classification head, which `forward` never calls, with placeholders.
        self.resnet.avgpool = Identity()
        self.resnet.fc = Identity()
class Decoder(nn.Layer, KaimingInitMixin):
    """
    Fuse four backbone feature maps into a single feature map.

    Each input is reduced to 96 channels by a 1x1 convolution, resized to the spatial size
    of the first one, concatenated, and passed through the output convolutions.
    The expected input channel counts are hard-coded as 64, 128, 256, and 512.

    Args:
        f_ch (int): The number of channels of the output features.
    """

    def __init__(self, f_ch):
        super().__init__()
        self.dr1 = Conv1x1(64, 96, norm=True, act=True)
        self.dr2 = Conv1x1(128, 96, norm=True, act=True)
        self.dr3 = Conv1x1(256, 96, norm=True, act=True)
        self.dr4 = Conv1x1(512, 96, norm=True, act=True)
        self.conv_out = nn.Sequential(
            Conv3x3(384, 256, norm=True, act=True),
            nn.Dropout(0.5),
            Conv1x1(256, f_ch, norm=True, act=True)
        )

        self.init_weight()

    def forward(self, feats):
        f1 = self.dr1(feats[0])

        def reduce_and_resize(reducer, feat):
            # Reduce channels, then resize bilinearly to the spatial size of `f1`.
            return F.interpolate(
                reducer(feat),
                size=paddle.shape(f1)[2:],
                mode='bilinear',
                align_corners=True
            )

        f2 = reduce_and_resize(self.dr2, feats[1])
        f3 = reduce_and_resize(self.dr3, feats[2])
        f4 = reduce_and_resize(self.dr4, feats[3])

        fused = paddle.concat([f1, f2, f3, f4], axis=1)
        return self.conv_out(fused)
class BAM(nn.Layer):
    """
    Basic attention module: one self-attention over all positions of the flattened input.

    `forward` merges the last two axes of its input before attention and splits the last
    axis back into `(-1, 2)` afterwards; `Attention` feeds it two feature maps stacked on
    the last axis.

    Args:
        in_ch (int): The number of channels of the input features.
        ds (int): The kernel size of the average pooling applied before attention; also used as
            the scale factor of the interpolation applied afterwards.
    """

    def __init__(self, in_ch, ds):
        super().__init__()

        self.ds = ds
        self.pool = nn.AvgPool2D(self.ds)

        self.val_ch = in_ch
        self.key_ch = in_ch // 8
        self.conv_q = Conv1x1(in_ch, self.key_ch)
        self.conv_k = Conv1x1(in_ch, self.key_ch)
        self.conv_v = Conv1x1(in_ch, self.val_ch)

        self.softmax = nn.Softmax(axis=-1)

    def forward(self, x):
        # Merge the last two axes so that the attention sees one feature map.
        x = x.flatten(-2)
        pooled = self.pool(x)
        b, c, h, w = paddle.shape(pooled)

        # Pairwise energies between all positions, scaled by key_ch ** -0.5.
        q = self.conv_q(pooled).reshape((b, -1, h * w)).transpose((0, 2, 1))
        k = self.conv_k(pooled).reshape((b, -1, h * w))
        scores = (self.key_ch ** (-0.5)) * paddle.bmm(q, k)
        weights = self.softmax(scores)

        # Aggregate the values with the attention weights.
        v = self.conv_v(pooled).reshape((b, -1, w * h))
        attended = paddle.bmm(v, weights.transpose((0, 2, 1)))
        attended = attended.reshape((b, c, h, w))

        # Undo the pooling and add the residual connection.
        attended = F.interpolate(attended, scale_factor=self.ds)
        out = attended + x

        # Split the last axis back into (-1, 2).
        return out.reshape(out.shape[:-1] + [out.shape[-1] // 2, 2])
class PAMBlock(nn.Layer):
    """
    Self-attention computed independently inside each cell of a `scale` x `scale` grid.

    Args:
        in_ch (int): The number of channels of the input features.
        scale (int, optional): The number of subregions along each spatial axis. The height and
            width of the pooled features must be divisible by it. Default: 1.
        ds (int, optional): The kernel size of the average pooling applied before attention; also
            used as the scale factor of the interpolation applied afterwards. Default: 1.
    """

    def __init__(self, in_ch, scale=1, ds=1):
        super().__init__()

        self.scale = scale
        self.ds = ds
        self.pool = nn.AvgPool2D(self.ds)

        self.val_ch = in_ch
        self.key_ch = in_ch // 8
        self.conv_q = Conv1x1(in_ch, self.key_ch, norm=True)
        self.conv_k = Conv1x1(in_ch, self.key_ch, norm=True)
        self.conv_v = Conv1x1(in_ch, self.val_ch)

    def forward(self, x):
        pooled = self.pool(x)

        # Project to query, key, and value.
        q = self.conv_q(pooled)
        k = self.conv_k(pooled)
        v = self.conv_v(pooled)

        # Cut each map into subregions that are folded into the batch axis.
        b, c, h, w = paddle.shape(pooled)
        q = self._split_subregions(q)
        k = self._split_subregions(k)
        v = self._split_subregions(v)

        # Attend within every subregion, then stitch the subregions back together.
        attended = self._attend(q, k, v)
        attended = self._recons_whole(attended, b, c, h, w)

        # Undo the pooling.
        return F.interpolate(attended, scale_factor=self.ds)

    def _attend(self, query, key, value):
        # Batched matrix product gives pairwise energies, scaled by key_ch ** -0.5.
        scores = paddle.bmm(query.transpose((0, 2, 1)), key)
        weights = F.softmax((self.key_ch ** (-0.5)) * scores, axis=-1)
        return paddle.bmm(value, weights.transpose((0, 2, 1)))

    def _split_subregions(self, x):
        s = self.scale
        b, c, h, w = paddle.shape(x)
        assert h % self.scale == 0 and w % self.scale == 0
        # (b, c, h, w) -> (b, c, s, h/s, s, w/s) -> (b*s*s, c, -1)
        grid = x.reshape((b, c, s, h // s, s, w // s))
        return grid.transpose((0, 2, 4, 1, 3, 5)).reshape((b * s * s, c, -1))

    def _recons_whole(self, x, b, c, h, w):
        s = self.scale
        # Inverse of `_split_subregions`: (b*s*s, c, -1) -> (b, s, s, c, h/s, w/s) -> (b, c, h, w)
        grid = x.reshape((b, s, s, c, h // s, w // s))
        return grid.transpose((0, 3, 1, 4, 2, 5)).reshape((b, c, h, w))
class PAM(nn.Layer):
    """
    Pyramid attention module: several `PAMBlock`s with different subregion scales whose
    outputs are concatenated and fused by a 1x1 convolution.

    `forward` merges the last two axes of its input before attention and splits the last
    axis back into `(-1, 2)` afterwards; `Attention` feeds it two feature maps stacked on
    the last axis.

    Args:
        in_ch (int): The number of channels of the input features.
        ds (int): The downsampling factor passed to every `PAMBlock`.
        scales (tuple, optional): The subregion scales, one `PAMBlock` per entry. Default: (1,2,4,8).
    """

    def __init__(self, in_ch, ds, scales=(1,2,4,8)):
        super().__init__()

        blocks = [PAMBlock(in_ch, scale=s, ds=ds) for s in scales]
        self.stages = nn.LayerList(blocks)
        self.conv_out = Conv1x1(in_ch * len(scales), in_ch, bias=False)

    def forward(self, x):
        # Merge the last two axes so that every stage sees one feature map.
        x = x.flatten(-2)
        per_scale = []
        for stage in self.stages:
            per_scale.append(stage(x))
        out = self.conv_out(paddle.concat(per_scale, axis=1))
        # Split the last axis back into (-1, 2).
        return out.reshape(out.shape[:-1] + [out.shape[-1] // 2, 2])
class Attention(nn.Layer):
    """
    Wrapper that applies one attention module jointly to two feature maps.

    Args:
        att (nn.Layer): The attention module. It receives the two inputs stacked on a new last
            axis and must return a tensor whose last axis can be indexed with 0 and 1.
    """

    def __init__(self, att):
        super().__init__()
        self.att = att

    def forward(self, x1, x2):
        # Stack the two inputs on a new trailing axis, attend, and unstack the result.
        stacked = paddle.stack([x1, x2], axis=-1)
        attended = self.att(stacked)
        return attended[..., 0], attended[..., 1]

@ -0,0 +1,201 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity
from .param_init import normal_init, constant_init
class UNetEarlyFusion(nn.Layer):
    """
    The FC-EF implementation based on PaddlePaddle.

    The original article refers to
        Caye Daudt, R., et al. "Fully convolutional siamese networks for change detection"
        (https://arxiv.org/abs/1810.08462).

    Note that `forward` concatenates `t1` and `t2` along axis 1 and feeds the result to the first
    convolution, which is built with `in_channels` input channels. `in_channels` therefore has to
    match the channel count of the concatenated pair.

    Args:
        in_channels (int): The number of bands of the input images.
        num_classes (int): The number of target classes.
        use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the
            model is trained on a relatively small dataset, the dropout layers help prevent overfitting.
            Default: False.
    """

    def __init__(
            self,
            in_channels,
            num_classes,
            use_dropout=False
    ):
        super().__init__()

        C1, C2, C3, C4, C5 = 16, 32, 64, 128, 256

        self.use_dropout = use_dropout

        def conv(cin, cout):
            # 3x3 convolution followed by normalization and activation.
            return Conv3x3(cin, cout, norm=True, act=True)

        def upconv(ch):
            # Stride-2 transposed convolution that keeps the channel count.
            return ConvTransposed3x3(ch, ch, output_padding=1)

        drop = self._make_dropout

        # Encoder stage 1
        self.conv11 = conv(in_channels, C1)
        self.do11 = drop()
        self.conv12 = conv(C1, C1)
        self.do12 = drop()
        self.pool1 = MaxPool2x2()

        # Encoder stage 2
        self.conv21 = conv(C1, C2)
        self.do21 = drop()
        self.conv22 = conv(C2, C2)
        self.do22 = drop()
        self.pool2 = MaxPool2x2()

        # Encoder stage 3
        self.conv31 = conv(C2, C3)
        self.do31 = drop()
        self.conv32 = conv(C3, C3)
        self.do32 = drop()
        self.conv33 = conv(C3, C3)
        self.do33 = drop()
        self.pool3 = MaxPool2x2()

        # Encoder stage 4
        self.conv41 = conv(C3, C4)
        self.do41 = drop()
        self.conv42 = conv(C4, C4)
        self.do42 = drop()
        self.conv43 = conv(C4, C4)
        self.do43 = drop()
        self.pool4 = MaxPool2x2()

        # Decoder stage 4 (input: upsampled features + one skip connection)
        self.upconv4 = upconv(C4)
        self.conv43d = conv(C5, C4)
        self.do43d = drop()
        self.conv42d = conv(C4, C4)
        self.do42d = drop()
        self.conv41d = conv(C4, C3)
        self.do41d = drop()

        # Decoder stage 3
        self.upconv3 = upconv(C3)
        self.conv33d = conv(C4, C3)
        self.do33d = drop()
        self.conv32d = conv(C3, C3)
        self.do32d = drop()
        self.conv31d = conv(C3, C2)
        self.do31d = drop()

        # Decoder stage 2
        self.upconv2 = upconv(C2)
        self.conv22d = conv(C3, C2)
        self.do22d = drop()
        self.conv21d = conv(C2, C1)
        self.do21d = drop()

        # Decoder stage 1
        self.upconv1 = upconv(C1)
        self.conv12d = conv(C2, C1)
        self.do12d = drop()
        # Final classifier: plain 3x3 convolution without normalization or activation.
        self.conv11d = Conv3x3(C1, num_classes)

        self.init_weight()

    def forward(self, t1, t2):
        def pad_like(up, ref):
            # Replicate-pad `up` by the size difference to `ref` on axes 3 and 2,
            # using the pad spec (0, diff_axis3, 0, diff_axis2).
            pad = (
                0,
                paddle.shape(ref)[3] - paddle.shape(up)[3],
                0,
                paddle.shape(ref)[2] - paddle.shape(up)[2]
            )
            return F.pad(up, pad=pad, mode='replicate')

        # Early fusion: both phases are concatenated before entering the network.
        y = paddle.concat([t1, t2], axis=1)

        # Encoder; the last feature map of every stage is kept as a skip connection.
        y = self.do11(self.conv11(y))
        skip1 = self.do12(self.conv12(y))

        y = self.do21(self.conv21(self.pool1(skip1)))
        skip2 = self.do22(self.conv22(y))

        y = self.do31(self.conv31(self.pool2(skip2)))
        y = self.do32(self.conv32(y))
        skip3 = self.do33(self.conv33(y))

        y = self.do41(self.conv41(self.pool3(skip3)))
        y = self.do42(self.conv42(y))
        skip4 = self.do43(self.conv43(y))

        y = self.pool4(skip4)

        # Decoder stage 4
        y = paddle.concat([pad_like(self.upconv4(y), skip4), skip4], 1)
        y = self.do43d(self.conv43d(y))
        y = self.do42d(self.conv42d(y))
        y = self.do41d(self.conv41d(y))

        # Decoder stage 3
        y = paddle.concat([pad_like(self.upconv3(y), skip3), skip3], 1)
        y = self.do33d(self.conv33d(y))
        y = self.do32d(self.conv32d(y))
        y = self.do31d(self.conv31d(y))

        # Decoder stage 2
        y = paddle.concat([pad_like(self.upconv2(y), skip2), skip2], 1)
        y = self.do22d(self.conv22d(y))
        y = self.do21d(self.conv21d(y))

        # Decoder stage 1
        y = paddle.concat([pad_like(self.upconv1(y), skip1), skip1], 1)
        y = self.do12d(self.conv12d(y))

        # The prediction is returned wrapped in a 1-tuple.
        return (self.conv11d(y),)

    def init_weight(self):
        """Initialize Conv2D weights from N(0, 0.001) and batch-norm affine parameters to 1 / 0."""
        for sublayer in self.sublayers():
            if isinstance(sublayer, nn.Conv2D):
                normal_init(sublayer.weight, std=0.001)
                continue
            if isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)):
                constant_init(sublayer.weight, value=1.0)
                constant_init(sublayer.bias, value=0.0)

    def _make_dropout(self):
        # An Identity placeholder keeps `forward` free of branches when dropout is disabled.
        return nn.Dropout2D(p=0.2) if self.use_dropout else Identity()

@ -0,0 +1,224 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity
from .param_init import normal_init, constant_init
class UNetSiamConc(nn.Layer):
"""
The FC-Siam-conc implementation based on PaddlePaddle.
The original article refers to
Caye Daudt, R., et al. "Fully convolutional siamese networks for change detection"
(https://arxiv.org/abs/1810.08462).
Args:
in_channels (int): The number of bands of the input images.
num_classes (int): The number of target classes.
use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is trained
on a relatively small dataset, the dropout layers help prevent overfitting. Default: False.
"""
def __init__(
self,
in_channels,
num_classes,
use_dropout=False
):
super().__init__()
C1, C2, C3, C4, C5 = 16, 32, 64, 128, 256
self.use_dropout = use_dropout
self.conv11 = Conv3x3(in_channels, C1, norm=True, act=True)
self.do11 = self._make_dropout()
self.conv12 = Conv3x3(C1, C1, norm=True, act=True)
self.do12 = self._make_dropout()
self.pool1 = MaxPool2x2()
self.conv21 = Conv3x3(C1, C2, norm=True, act=True)
self.do21 = self._make_dropout()
self.conv22 = Conv3x3(C2, C2, norm=True, act=True)
self.do22 = self._make_dropout()
self.pool2 = MaxPool2x2()
self.conv31 = Conv3x3(C2, C3, norm=True, act=True)
self.do31 = self._make_dropout()
self.conv32 = Conv3x3(C3, C3, norm=True, act=True)
self.do32 = self._make_dropout()
self.conv33 = Conv3x3(C3, C3, norm=True, act=True)
self.do33 = self._make_dropout()
self.pool3 = MaxPool2x2()
self.conv41 = Conv3x3(C3, C4, norm=True, act=True)
self.do41 = self._make_dropout()
self.conv42 = Conv3x3(C4, C4, norm=True, act=True)
self.do42 = self._make_dropout()
self.conv43 = Conv3x3(C4, C4, norm=True, act=True)
self.do43 = self._make_dropout()
self.pool4 = MaxPool2x2()
self.upconv4 = ConvTransposed3x3(C4, C4, output_padding=1)
self.conv43d = Conv3x3(C5+C4, C4, norm=True, act=True)
self.do43d = self._make_dropout()
self.conv42d = Conv3x3(C4, C4, norm=True, act=True)
self.do42d = self._make_dropout()
self.conv41d = Conv3x3(C4, C3, norm=True, act=True)
self.do41d = self._make_dropout()
self.upconv3 = ConvTransposed3x3(C3, C3, output_padding=1)
self.conv33d = Conv3x3(C4+C3, C3, norm=True, act=True)
self.do33d = self._make_dropout()
self.conv32d = Conv3x3(C3, C3, norm=True, act=True)
self.do32d = self._make_dropout()
self.conv31d = Conv3x3(C3, C2, norm=True, act=True)
self.do31d = self._make_dropout()
self.upconv2 = ConvTransposed3x3(C2, C2, output_padding=1)
self.conv22d = Conv3x3(C3+C2, C2, norm=True, act=True)
self.do22d = self._make_dropout()
self.conv21d = Conv3x3(C2, C1, norm=True, act=True)
self.do21d = self._make_dropout()
self.upconv1 = ConvTransposed3x3(C1, C1, output_padding=1)
self.conv12d = Conv3x3(C2+C1, C1, norm=True, act=True)
self.do12d = self._make_dropout()
self.conv11d = Conv3x3(C1, num_classes)
self.init_weight()
def forward(self, t1, t2):
# Encode t1
# Stage 1
x11 = self.do11(self.conv11(t1))
x12_1 = self.do12(self.conv12(x11))
x1p = self.pool1(x12_1)
# Stage 2
x21 = self.do21(self.conv21(x1p))
x22_1 = self.do22(self.conv22(x21))
x2p = self.pool2(x22_1)
# Stage 3
x31 = self.do31(self.conv31(x2p))
x32 = self.do32(self.conv32(x31))
x33_1 = self.do33(self.conv33(x32))
x3p = self.pool3(x33_1)
# Stage 4
x41 = self.do41(self.conv41(x3p))
x42 = self.do42(self.conv42(x41))
x43_1 = self.do43(self.conv43(x42))
x4p = self.pool4(x43_1)
# Encode t2
# Stage 1
x11 = self.do11(self.conv11(t2))
x12_2 = self.do12(self.conv12(x11))
x1p = self.pool1(x12_2)
# Stage 2
x21 = self.do21(self.conv21(x1p))
x22_2 = self.do22(self.conv22(x21))
x2p = self.pool2(x22_2)
# Stage 3
x31 = self.do31(self.conv31(x2p))
x32 = self.do32(self.conv32(x31))
x33_2 = self.do33(self.conv33(x32))
x3p = self.pool3(x33_2)
# Stage 4
x41 = self.do41(self.conv41(x3p))
x42 = self.do42(self.conv42(x41))
x43_2 = self.do43(self.conv43(x42))
x4p = self.pool4(x43_2)
# Decode
# Stage 4d
x4d = self.upconv4(x4p)
pad4 = (
0,
paddle.shape(x43_1)[3]-paddle.shape(x4d)[3],
0,
paddle.shape(x43_1)[2]-paddle.shape(x4d)[2]
)
x4d = paddle.concat([F.pad(x4d, pad=pad4, mode='replicate'), x43_1, x43_2], 1)
x43d = self.do43d(self.conv43d(x4d))
x42d = self.do42d(self.conv42d(x43d))
x41d = self.do41d(self.conv41d(x42d))
# Stage 3d
x3d = self.upconv3(x41d)
pad3 = (
0,
paddle.shape(x33_1)[3]-paddle.shape(x3d)[3],
0,
paddle.shape(x33_1)[2]-paddle.shape(x3d)[2]
)
x3d = paddle.concat([F.pad(x3d, pad=pad3, mode='replicate'), x33_1, x33_2], 1)
x33d = self.do33d(self.conv33d(x3d))
x32d = self.do32d(self.conv32d(x33d))
x31d = self.do31d(self.conv31d(x32d))
# Stage 2d
x2d = self.upconv2(x31d)
pad2 = (
0,
paddle.shape(x22_1)[3]-paddle.shape(x2d)[3],
0,
paddle.shape(x22_1)[2]-paddle.shape(x2d)[2]
)
x2d = paddle.concat([F.pad(x2d, pad=pad2, mode='replicate'), x22_1, x22_2], 1)
x22d = self.do22d(self.conv22d(x2d))
x21d = self.do21d(self.conv21d(x22d))
# Stage 1d
x1d = self.upconv1(x21d)
pad1 = (
0,
paddle.shape(x12_1)[3]-paddle.shape(x1d)[3],
0,
paddle.shape(x12_1)[2]-paddle.shape(x1d)[2]
)
x1d = paddle.concat([F.pad(x1d, pad=pad1, mode='replicate'), x12_1, x12_2], 1)
x12d = self.do12d(self.conv12d(x1d))
x11d = self.conv11d(x12d)
return x11d,
def init_weight(self):
    """
    (Re)initialize the parameters of the convolution and batch-norm sublayers.

    Every `nn.Conv2D` sublayer gets its weight initialized via `normal_init` with std=0.001
    (its bias is not touched here). Every `nn.BatchNorm` / `nn.SyncBatchNorm` sublayer gets
    weight set to 1.0 and bias set to 0.0 via `constant_init`. Other sublayers are left as-is.
    """
    norm_types = (nn.BatchNorm, nn.SyncBatchNorm)
    for layer in self.sublayers():
        if isinstance(layer, nn.Conv2D):
            normal_init(layer.weight, std=0.001)
            continue
        if isinstance(layer, norm_types):
            constant_init(layer.weight, value=1.0)
            constant_init(layer.bias, value=0.0)
def _make_dropout(self):
    """
    Build the layer placed after a convolution block.

    Returns:
        nn.Dropout2D with p=0.2 when `self.use_dropout` is truthy, otherwise a new
        `Identity` layer, so callers can apply the result unconditionally.
    """
    return nn.Dropout2D(p=0.2) if self.use_dropout else Identity()

@ -0,0 +1,224 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity
from .param_init import normal_init, constant_init
class UNetSiamDiff(nn.Layer):
    """
    The FC-Siam-diff implementation based on PaddlePaddle.

    The original article refers to
        Caye Daudt, R., et al. "Fully convolutional siamese networks for change detection"
        (https://arxiv.org/abs/1810.08462).

    Both inputs are run through the same encoder layers, i.e. the encoder weights are shared.
    At every decoder stage the upsampled features are concatenated (axis 1) with the absolute
    difference of the two encoder feature maps taken at the matching encoder stage.

    Args:
        in_channels (int): The number of bands of the input images.
        num_classes (int): The number of target classes.
        use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is trained
            on a relatively small dataset, the dropout layers help prevent overfitting. Default: False.
    """

    def __init__(self, in_channels, num_classes, use_dropout=False):
        super().__init__()

        # Has to be set before the first `_make_dropout()` call below.
        self.use_dropout = use_dropout

        c1, c2, c3, c4, c5 = 16, 32, 64, 128, 256

        def conv(cin, cout):
            # Every block except the final classifier uses norm + activation.
            return Conv3x3(cin, cout, norm=True, act=True)

        # NOTE: keep the creation order of the sublayers unchanged; it determines the order
        # returned by `sublayers()` and of the entries in the state dict.

        # Encoder, stage 1
        self.conv11 = conv(in_channels, c1)
        self.do11 = self._make_dropout()
        self.conv12 = conv(c1, c1)
        self.do12 = self._make_dropout()
        self.pool1 = MaxPool2x2()

        # Encoder, stage 2
        self.conv21 = conv(c1, c2)
        self.do21 = self._make_dropout()
        self.conv22 = conv(c2, c2)
        self.do22 = self._make_dropout()
        self.pool2 = MaxPool2x2()

        # Encoder, stage 3
        self.conv31 = conv(c2, c3)
        self.do31 = self._make_dropout()
        self.conv32 = conv(c3, c3)
        self.do32 = self._make_dropout()
        self.conv33 = conv(c3, c3)
        self.do33 = self._make_dropout()
        self.pool3 = MaxPool2x2()

        # Encoder, stage 4
        self.conv41 = conv(c3, c4)
        self.do41 = self._make_dropout()
        self.conv42 = conv(c4, c4)
        self.do42 = self._make_dropout()
        self.conv43 = conv(c4, c4)
        self.do43 = self._make_dropout()
        self.pool4 = MaxPool2x2()

        # Decoder, stage 4. The first conv of each decoder stage consumes the concatenation
        # of the upsampled features and the encoder difference, hence the doubled width
        # (here c4 + c4 = c5).
        self.upconv4 = ConvTransposed3x3(c4, c4, output_padding=1)
        self.conv43d = conv(c5, c4)
        self.do43d = self._make_dropout()
        self.conv42d = conv(c4, c4)
        self.do42d = self._make_dropout()
        self.conv41d = conv(c4, c3)
        self.do41d = self._make_dropout()

        # Decoder, stage 3 (c3 + c3 = c4 input channels)
        self.upconv3 = ConvTransposed3x3(c3, c3, output_padding=1)
        self.conv33d = conv(c4, c3)
        self.do33d = self._make_dropout()
        self.conv32d = conv(c3, c3)
        self.do32d = self._make_dropout()
        self.conv31d = conv(c3, c2)
        self.do31d = self._make_dropout()

        # Decoder, stage 2 (c2 + c2 = c3 input channels)
        self.upconv2 = ConvTransposed3x3(c2, c2, output_padding=1)
        self.conv22d = conv(c3, c2)
        self.do22d = self._make_dropout()
        self.conv21d = conv(c2, c1)
        self.do21d = self._make_dropout()

        # Decoder, stage 1 (c1 + c1 = c2 input channels)
        self.upconv1 = ConvTransposed3x3(c1, c1, output_padding=1)
        self.conv12d = conv(c2, c1)
        self.do12d = self._make_dropout()
        # Classifier: built with Conv3x3's default norm/act settings.
        self.conv11d = Conv3x3(c1, num_classes)

        self.init_weight()

    def forward(self, t1, t2):
        """
        Run the network on a pair of images.

        Args:
            t1 (paddle.Tensor): The first image. Presumably NCHW (axes 2 and 3 are treated as
                the spatial axes below) -- confirm against the caller.
            t2 (paddle.Tensor): The second image, same layout as `t1`.

        Returns:
            tuple: A one-element tuple containing the output of the final convolution.
        """
        # Siamese encoding: the very same layers process both inputs. Only the pooled
        # stage-4 output of `t2` feeds the decoder; the one of `t1` is discarded.
        skips_a, _ = self._encode(t1)
        skips_b, bottom = self._encode(t2)

        # Stage 4d
        y = self._merge(self.upconv4(bottom), skips_a[3], skips_b[3])
        y = self.do43d(self.conv43d(y))
        y = self.do42d(self.conv42d(y))
        y = self.do41d(self.conv41d(y))

        # Stage 3d
        y = self._merge(self.upconv3(y), skips_a[2], skips_b[2])
        y = self.do33d(self.conv33d(y))
        y = self.do32d(self.conv32d(y))
        y = self.do31d(self.conv31d(y))

        # Stage 2d
        y = self._merge(self.upconv2(y), skips_a[1], skips_b[1])
        y = self.do22d(self.conv22d(y))
        y = self.do21d(self.conv21d(y))

        # Stage 1d
        y = self._merge(self.upconv1(y), skips_a[0], skips_b[0])
        y = self.do12d(self.conv12d(y))

        # The trailing comma is intentional: the result is wrapped in a 1-tuple.
        return self.conv11d(y),

    def _encode(self, x):
        """
        Apply the four encoder stages to a single image.

        Returns:
            tuple: `(skips, pooled)` where `skips` holds the pre-pooling features of stages
                1 to 4 (in that order) and `pooled` is the output of the last pooling layer.
        """
        # Stage 1
        x = self.do11(self.conv11(x))
        f1 = self.do12(self.conv12(x))
        x = self.pool1(f1)
        # Stage 2
        x = self.do21(self.conv21(x))
        f2 = self.do22(self.conv22(x))
        x = self.pool2(f2)
        # Stage 3
        x = self.do31(self.conv31(x))
        x = self.do32(self.conv32(x))
        f3 = self.do33(self.conv33(x))
        x = self.pool3(f3)
        # Stage 4
        x = self.do41(self.conv41(x))
        x = self.do42(self.conv42(x))
        f4 = self.do43(self.conv43(x))
        x = self.pool4(f4)
        return (f1, f2, f3, f4), x

    def _merge(self, upsampled, skip_a, skip_b):
        """
        Align `upsampled` with `skip_a` and append the encoder difference.

        `upsampled` is padded (mode 'replicate') by the size gap to `skip_a` along axes 2 and 3.
        With the 4-element pad given here only the second entry of each pair is non-zero,
        i.e. the padding goes to the right/bottom under F.pad's (left, right, top, bottom)
        convention for 4-D inputs. The padded tensor is then concatenated along axis 1 with
        `abs(skip_a - skip_b)`.
        """
        ref_shape = paddle.shape(skip_a)
        cur_shape = paddle.shape(upsampled)
        pad = (
            0,
            ref_shape[3] - cur_shape[3],
            0,
            ref_shape[2] - cur_shape[2],
        )
        aligned = F.pad(upsampled, pad=pad, mode='replicate')
        difference = paddle.abs(skip_a - skip_b)
        return paddle.concat([aligned, difference], 1)

    def init_weight(self):
        """
        Initialize conv weights from a normal distribution (std=0.001) and reset batch-norm
        layers to weight 1.0 / bias 0.0. Conv biases are not touched here.
        """
        norm_types = (nn.BatchNorm, nn.SyncBatchNorm)
        for layer in self.sublayers():
            if isinstance(layer, nn.Conv2D):
                normal_init(layer.weight, std=0.001)
                continue
            if isinstance(layer, norm_types):
                constant_init(layer.weight, value=1.0)
                constant_init(layer.bias, value=0.0)

    def _make_dropout(self):
        """Return `nn.Dropout2D(p=0.2)` if dropout is enabled, otherwise an `Identity` layer."""
        return nn.Dropout2D(p=0.2) if self.use_dropout else Identity()

@ -31,7 +31,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict
from paddlers.transforms import ImgDecoder, Resize
import paddlers.models.cd as cd
# NOTE(review): this one-element list looks like the pre-change line left over from the diff
# view; it has no lasting effect because the assignment right below overwrites it.
__all__ = ["CDNet"]
# Names exported by a star-import of this module.
__all__ = ["CDNet", "UNetEarlyFusion", "UNetSiamConc", "UNetSiamDiff", "STANet", "BIT", "SNUNet", "DSIFN", "DSAMNet"]
class BaseChangeDetector(BaseModel):
@ -663,3 +663,190 @@ class CDNet(BaseChangeDetector):
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
**params)
class UNetEarlyFusion(BaseChangeDetector):
    """
    Change detector configured with the model name 'UNetEarlyFusion'.

    The network-specific options are merged into `params` and handed to
    `BaseChangeDetector.__init__` together with the common arguments.

    Args:
        num_classes (int, optional): Passed to the base class as `num_classes`. Default: 2.
        use_mixed_loss (optional): Passed to the base class unchanged. Default: False.
        in_channels (int, optional): Forwarded as the `in_channels` entry of `params`. Default: 6.
        use_dropout (bool, optional): Forwarded as the `use_dropout` entry of `params`. Default: False.
        **params: Any further keyword arguments, forwarded to the base class unchanged.
    """

    def __init__(self,
                 num_classes=2,
                 use_mixed_loss=False,
                 in_channels=6,
                 use_dropout=False,
                 **params):
        params.update(in_channels=in_channels, use_dropout=use_dropout)
        super().__init__(
            model_name='UNetEarlyFusion',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            **params)
class UNetSiamConc(BaseChangeDetector):
    """
    Change detector configured with the model name 'UNetSiamConc'.

    The network-specific options are merged into `params` and handed to
    `BaseChangeDetector.__init__` together with the common arguments.

    Args:
        num_classes (int, optional): Passed to the base class as `num_classes`. Default: 2.
        use_mixed_loss (optional): Passed to the base class unchanged. Default: False.
        in_channels (int, optional): Forwarded as the `in_channels` entry of `params`. Default: 3.
        use_dropout (bool, optional): Forwarded as the `use_dropout` entry of `params`. Default: False.
        **params: Any further keyword arguments, forwarded to the base class unchanged.
    """

    def __init__(self,
                 num_classes=2,
                 use_mixed_loss=False,
                 in_channels=3,
                 use_dropout=False,
                 **params):
        params.update(in_channels=in_channels, use_dropout=use_dropout)
        super().__init__(
            model_name='UNetSiamConc',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            **params)
class UNetSiamDiff(BaseChangeDetector):
    """
    Change detector configured with the model name 'UNetSiamDiff'.

    The network-specific options are merged into `params` and handed to
    `BaseChangeDetector.__init__` together with the common arguments.

    Args:
        num_classes (int, optional): Passed to the base class as `num_classes`. Default: 2.
        use_mixed_loss (optional): Passed to the base class unchanged. Default: False.
        in_channels (int, optional): Forwarded as the `in_channels` entry of `params`. Default: 3.
        use_dropout (bool, optional): Forwarded as the `use_dropout` entry of `params`. Default: False.
        **params: Any further keyword arguments, forwarded to the base class unchanged.
    """

    def __init__(self,
                 num_classes=2,
                 use_mixed_loss=False,
                 in_channels=3,
                 use_dropout=False,
                 **params):
        params.update(in_channels=in_channels, use_dropout=use_dropout)
        super().__init__(
            model_name='UNetSiamDiff',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            **params)
class STANet(BaseChangeDetector):
    """
    Change detector configured with the model name 'STANet'.

    The network-specific options are merged into `params` and handed to
    `BaseChangeDetector.__init__` together with the common arguments.

    Args:
        num_classes (int, optional): Passed to the base class as `num_classes`. Default: 2.
        use_mixed_loss (optional): Passed to the base class unchanged. Default: False.
        in_channels (int, optional): Forwarded as the `in_channels` entry of `params`. Default: 3.
        att_type (str, optional): Forwarded as the `att_type` entry of `params`. Default: 'BAM'.
        ds_factor (int, optional): Forwarded as the `ds_factor` entry of `params`. Default: 1.
        **params: Any further keyword arguments, forwarded to the base class unchanged.
    """

    def __init__(self,
                 num_classes=2,
                 use_mixed_loss=False,
                 in_channels=3,
                 att_type='BAM',
                 ds_factor=1,
                 **params):
        params.update(
            in_channels=in_channels,
            att_type=att_type,
            ds_factor=ds_factor)
        super().__init__(
            model_name='STANet',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            **params)
class BIT(BaseChangeDetector):
    """
    Change detector configured with the model name 'BIT'.

    All network-specific options listed below are merged into `params` under their own
    names and handed to `BaseChangeDetector.__init__` together with the common arguments.

    Args:
        num_classes (int, optional): Passed to the base class as `num_classes`. Default: 2.
        use_mixed_loss (optional): Passed to the base class unchanged. Default: False.
        in_channels (int, optional): Default: 3.
        backbone (str, optional): Default: 'resnet18'.
        n_stages (int, optional): Default: 4.
        use_tokenizer (bool, optional): Default: True.
        token_len (int, optional): Default: 4.
        pool_mode (str, optional): Default: 'max'.
        pool_size (int, optional): Default: 2.
        enc_with_pos (bool, optional): Default: True.
        enc_depth (int, optional): Default: 1.
        enc_head_dim (int, optional): Default: 64.
        dec_depth (int, optional): Default: 8.
        dec_head_dim (int, optional): Default: 8.
        **params: Any further keyword arguments, forwarded to the base class unchanged.
    """

    def __init__(self,
                 num_classes=2,
                 use_mixed_loss=False,
                 in_channels=3,
                 backbone='resnet18',
                 n_stages=4,
                 use_tokenizer=True,
                 token_len=4,
                 pool_mode='max',
                 pool_size=2,
                 enc_with_pos=True,
                 enc_depth=1,
                 enc_head_dim=64,
                 dec_depth=8,
                 dec_head_dim=8,
                 **params):
        params.update(
            in_channels=in_channels,
            backbone=backbone,
            n_stages=n_stages,
            use_tokenizer=use_tokenizer,
            token_len=token_len,
            pool_mode=pool_mode,
            pool_size=pool_size,
            enc_with_pos=enc_with_pos,
            enc_depth=enc_depth,
            enc_head_dim=enc_head_dim,
            dec_depth=dec_depth,
            dec_head_dim=dec_head_dim)
        super().__init__(
            model_name='BIT',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            **params)
class SNUNet(BaseChangeDetector):
    """
    Change detector configured with the model name 'SNUNet'.

    The network-specific options are merged into `params` and handed to
    `BaseChangeDetector.__init__` together with the common arguments.

    Args:
        num_classes (int, optional): Passed to the base class as `num_classes`. Default: 2.
        use_mixed_loss (optional): Passed to the base class unchanged. Default: False.
        in_channels (int, optional): Forwarded as the `in_channels` entry of `params`. Default: 3.
        width (int, optional): Forwarded as the `width` entry of `params`. Default: 32.
        **params: Any further keyword arguments, forwarded to the base class unchanged.
    """

    def __init__(self,
                 num_classes=2,
                 use_mixed_loss=False,
                 in_channels=3,
                 width=32,
                 **params):
        params.update(in_channels=in_channels, width=width)
        super().__init__(
            model_name='SNUNet',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            **params)
class DSIFN(BaseChangeDetector):
    """
    Change detector configured with the model name 'DSIFN'.

    The loss specification is fixed by this class: five cross-entropy losses, each with
    coefficient 1.0, are stored in `self.losses` after the base class has been initialized.

    Args:
        num_classes (int, optional): Passed to the base class as `num_classes`. Default: 2.
        use_mixed_loss (None, optional): Must be None (see the HACK note below). Default: None.
        use_dropout (bool, optional): Forwarded as the `use_dropout` entry of `params`. Default: False.
        **params: Any further keyword arguments, forwarded to the base class unchanged.

    Raises:
        ValueError: If `use_mixed_loss` is not None.
    """

    # Number of loss terms configured for this model.
    _NUM_LOSSES = 5

    def __init__(self,
                 num_classes=2,
                 use_mixed_loss=None,
                 use_dropout=False,
                 **params):
        # HACK: currently the only legal value of `use_mixed_loss` is None, in which case the loss specifications are
        # constructed automatically.
        # Checked with an explicit raise (an `assert` disappears under `python -O`) and before
        # the base constructor runs, so an invalid argument is rejected up front.
        if use_mixed_loss is not None:
            raise ValueError(
                "`use_mixed_loss` must be None for DSIFN, but got {!r}.".format(
                    use_mixed_loss))

        params.update({
            'use_dropout': use_dropout
        })
        super(DSIFN, self).__init__(
            model_name='DSIFN',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            **params)

        self.losses = {
            # One independent loss object per term, rather than five references to a single
            # shared instance, so the entries cannot alias each other.
            'types': [
                paddleseg.models.CrossEntropyLoss()
                for _ in range(self._NUM_LOSSES)
            ],
            'coef': [1.0] * self._NUM_LOSSES
        }
class DSAMNet(BaseChangeDetector):
    """
    Change detector configured with the model name 'DSAMNet'.

    The loss specification is fixed by this class: one cross-entropy loss (coefficient 1.0)
    followed by two Dice losses (coefficient 0.05 each) are stored in `self.losses` after the
    base class has been initialized.

    Args:
        num_classes (int, optional): Passed to the base class as `num_classes`. Default: 2.
        use_mixed_loss (None, optional): Must be None (see the HACK note below). Default: None.
        in_channels (int, optional): Forwarded as the `in_channels` entry of `params`. Default: 3.
        ca_ratio (int, optional): Forwarded as the `ca_ratio` entry of `params`. Default: 8.
        sa_kernel (int, optional): Forwarded as the `sa_kernel` entry of `params`. Default: 7.
        **params: Any further keyword arguments, forwarded to the base class unchanged.

    Raises:
        ValueError: If `use_mixed_loss` is not None.
    """

    def __init__(self,
                 num_classes=2,
                 use_mixed_loss=None,
                 in_channels=3,
                 ca_ratio=8,
                 sa_kernel=7,
                 **params):
        # HACK: currently the only legal value of `use_mixed_loss` is None, in which case the loss specifications are
        # constructed automatically.
        # Checked with an explicit raise (an `assert` disappears under `python -O`) and before
        # the base constructor runs, so an invalid argument is rejected up front.
        if use_mixed_loss is not None:
            raise ValueError(
                "`use_mixed_loss` must be None for DSAMNet, but got {!r}.".format(
                    use_mixed_loss))

        params.update({
            'in_channels': in_channels,
            'ca_ratio': ca_ratio,
            'sa_kernel': sa_kernel
        })
        super(DSAMNet, self).__init__(
            model_name='DSAMNet',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            **params)

        self.losses = {
            'types': [
                paddleseg.models.CrossEntropyLoss(),
                paddleseg.models.DiceLoss(),
                paddleseg.models.DiceLoss()
            ],
            'coef': [1.0, 0.05, 0.05]
        }
Loading…
Cancel
Save