commit
d5a7ddaf6f
16 changed files with 2710 additions and 2 deletions
@ -0,0 +1,13 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
@ -0,0 +1,358 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
# Adapted from https://github.com/PaddlePaddle/Paddle/blob/release/2.2/python/paddle/vision/models/resnet.py |
||||
## Original head information |
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
from __future__ import division |
||||
from __future__ import print_function |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
|
||||
from paddle.utils.download import get_weights_path_from_url |
||||
|
||||
__all__ = [] |
||||
|
||||
model_urls = { |
||||
'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams', |
||||
'cf548f46534aa3560945be4b95cd11c4'), |
||||
'resnet34': ('https://paddle-hapi.bj.bcebos.com/models/resnet34.pdparams', |
||||
'8d2275cf8706028345f78ac0e1d31969'), |
||||
'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams', |
||||
'ca6f485ee1ab0492d38f323885b0ad80'), |
||||
'resnet101': ('https://paddle-hapi.bj.bcebos.com/models/resnet101.pdparams', |
||||
'02f35f034ca3858e1e54d4036443c92d'), |
||||
'resnet152': ('https://paddle-hapi.bj.bcebos.com/models/resnet152.pdparams', |
||||
'7ad16a2f1e7333859ff986138630fd7a'), |
||||
} |
||||
|
||||
|
||||
class BasicBlock(nn.Layer): |
||||
expansion = 1 |
||||
|
||||
def __init__(self, |
||||
inplanes, |
||||
planes, |
||||
stride=1, |
||||
downsample=None, |
||||
groups=1, |
||||
base_width=64, |
||||
dilation=1, |
||||
norm_layer=None): |
||||
super(BasicBlock, self).__init__() |
||||
if norm_layer is None: |
||||
norm_layer = nn.BatchNorm2D |
||||
|
||||
if dilation > 1: |
||||
raise NotImplementedError( |
||||
"Dilation > 1 not supported in BasicBlock") |
||||
|
||||
self.conv1 = nn.Conv2D( |
||||
inplanes, planes, 3, padding=1, stride=stride, bias_attr=False) |
||||
self.bn1 = norm_layer(planes) |
||||
self.relu = nn.ReLU() |
||||
self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False) |
||||
self.bn2 = norm_layer(planes) |
||||
self.downsample = downsample |
||||
self.stride = stride |
||||
|
||||
def forward(self, x): |
||||
identity = x |
||||
|
||||
out = self.conv1(x) |
||||
out = self.bn1(out) |
||||
out = self.relu(out) |
||||
|
||||
out = self.conv2(out) |
||||
out = self.bn2(out) |
||||
|
||||
if self.downsample is not None: |
||||
identity = self.downsample(x) |
||||
|
||||
out += identity |
||||
out = self.relu(out) |
||||
|
||||
return out |
||||
|
||||
|
||||
class BottleneckBlock(nn.Layer): |
||||
|
||||
expansion = 4 |
||||
|
||||
def __init__(self, |
||||
inplanes, |
||||
planes, |
||||
stride=1, |
||||
downsample=None, |
||||
groups=1, |
||||
base_width=64, |
||||
dilation=1, |
||||
norm_layer=None): |
||||
super(BottleneckBlock, self).__init__() |
||||
if norm_layer is None: |
||||
norm_layer = nn.BatchNorm2D |
||||
width = int(planes * (base_width / 64.)) * groups |
||||
|
||||
self.conv1 = nn.Conv2D(inplanes, width, 1, bias_attr=False) |
||||
self.bn1 = norm_layer(width) |
||||
|
||||
self.conv2 = nn.Conv2D( |
||||
width, |
||||
width, |
||||
3, |
||||
padding=dilation, |
||||
stride=stride, |
||||
groups=groups, |
||||
dilation=dilation, |
||||
bias_attr=False) |
||||
self.bn2 = norm_layer(width) |
||||
|
||||
self.conv3 = nn.Conv2D( |
||||
width, planes * self.expansion, 1, bias_attr=False) |
||||
self.bn3 = norm_layer(planes * self.expansion) |
||||
self.relu = nn.ReLU() |
||||
self.downsample = downsample |
||||
self.stride = stride |
||||
|
||||
def forward(self, x): |
||||
identity = x |
||||
|
||||
out = self.conv1(x) |
||||
out = self.bn1(out) |
||||
out = self.relu(out) |
||||
|
||||
out = self.conv2(out) |
||||
out = self.bn2(out) |
||||
out = self.relu(out) |
||||
|
||||
out = self.conv3(out) |
||||
out = self.bn3(out) |
||||
|
||||
if self.downsample is not None: |
||||
identity = self.downsample(x) |
||||
|
||||
out += identity |
||||
out = self.relu(out) |
||||
|
||||
return out |
||||
|
||||
|
||||
class ResNet(nn.Layer): |
||||
"""ResNet model from |
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ |
||||
Args: |
||||
Block (BasicBlock|BottleneckBlock): block module of model. |
||||
depth (int): layers of resnet, default: 50. |
||||
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer |
||||
will not be defined. Default: 1000. |
||||
with_pool (bool): use pool before the last fc layer or not. Default: True. |
||||
Examples: |
||||
.. code-block:: python |
||||
from paddle.vision.models import ResNet |
||||
from paddle.vision.models.resnet import BottleneckBlock, BasicBlock |
||||
resnet50 = ResNet(BottleneckBlock, 50) |
||||
resnet18 = ResNet(BasicBlock, 18) |
||||
""" |
||||
|
||||
def __init__(self, block, depth, num_classes=1000, with_pool=True, strides=(1,1,2,2,2), norm_layer=None): |
||||
super(ResNet, self).__init__() |
||||
layer_cfg = { |
||||
18: [2, 2, 2, 2], |
||||
34: [3, 4, 6, 3], |
||||
50: [3, 4, 6, 3], |
||||
101: [3, 4, 23, 3], |
||||
152: [3, 8, 36, 3] |
||||
} |
||||
layers = layer_cfg[depth] |
||||
self.num_classes = num_classes |
||||
self.with_pool = with_pool |
||||
self._norm_layer = nn.BatchNorm2D if norm_layer is None else norm_layer |
||||
|
||||
self.inplanes = 64 |
||||
self.dilation = 1 |
||||
|
||||
self.conv1 = nn.Conv2D( |
||||
3, |
||||
self.inplanes, |
||||
kernel_size=7, |
||||
stride=strides[0], |
||||
padding=3, |
||||
bias_attr=False) |
||||
self.bn1 = self._norm_layer(self.inplanes) |
||||
self.relu = nn.ReLU() |
||||
self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) |
||||
self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[1]) |
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[2]) |
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[3]) |
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[4]) |
||||
if with_pool: |
||||
self.avgpool = nn.AdaptiveAvgPool2D((1, 1)) |
||||
|
||||
if num_classes > 0: |
||||
self.fc = nn.Linear(512 * block.expansion, num_classes) |
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False): |
||||
norm_layer = self._norm_layer |
||||
downsample = None |
||||
previous_dilation = self.dilation |
||||
if dilate: |
||||
self.dilation *= stride |
||||
stride = 1 |
||||
if stride != 1 or self.inplanes != planes * block.expansion: |
||||
downsample = nn.Sequential( |
||||
nn.Conv2D( |
||||
self.inplanes, |
||||
planes * block.expansion, |
||||
1, |
||||
stride=stride, |
||||
bias_attr=False), |
||||
norm_layer(planes * block.expansion), ) |
||||
|
||||
layers = [] |
||||
layers.append( |
||||
block(self.inplanes, planes, stride, downsample, 1, 64, |
||||
previous_dilation, norm_layer)) |
||||
self.inplanes = planes * block.expansion |
||||
for _ in range(1, blocks): |
||||
layers.append(block(self.inplanes, planes, norm_layer=norm_layer)) |
||||
|
||||
return nn.Sequential(*layers) |
||||
|
||||
def forward(self, x): |
||||
x = self.conv1(x) |
||||
x = self.bn1(x) |
||||
x = self.relu(x) |
||||
x = self.maxpool(x) |
||||
x = self.layer1(x) |
||||
x = self.layer2(x) |
||||
x = self.layer3(x) |
||||
x = self.layer4(x) |
||||
|
||||
if self.with_pool: |
||||
x = self.avgpool(x) |
||||
|
||||
if self.num_classes > 0: |
||||
x = paddle.flatten(x, 1) |
||||
x = self.fc(x) |
||||
|
||||
return x |
||||
|
||||
|
||||
def _resnet(arch, Block, depth, pretrained, **kwargs): |
||||
model = ResNet(Block, depth, **kwargs) |
||||
if pretrained: |
||||
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format( |
||||
arch) |
||||
weight_path = get_weights_path_from_url(model_urls[arch][0], |
||||
model_urls[arch][1]) |
||||
|
||||
param = paddle.load(weight_path) |
||||
model.set_dict(param) |
||||
|
||||
return model |
||||
|
||||
|
||||
def resnet18(pretrained=False, **kwargs): |
||||
"""ResNet 18-layer model |
||||
|
||||
Args: |
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet |
||||
Examples: |
||||
.. code-block:: python |
||||
from paddle.vision.models import resnet18 |
||||
# build model |
||||
model = resnet18() |
||||
# build model and load imagenet pretrained weight |
||||
# model = resnet18(pretrained=True) |
||||
""" |
||||
return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs) |
||||
|
||||
|
||||
def resnet34(pretrained=False, **kwargs): |
||||
"""ResNet 34-layer model |
||||
|
||||
Args: |
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet |
||||
|
||||
Examples: |
||||
.. code-block:: python |
||||
from paddle.vision.models import resnet34 |
||||
# build model |
||||
model = resnet34() |
||||
# build model and load imagenet pretrained weight |
||||
# model = resnet34(pretrained=True) |
||||
""" |
||||
return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs) |
||||
|
||||
|
||||
def resnet50(pretrained=False, **kwargs): |
||||
"""ResNet 50-layer model |
||||
|
||||
Args: |
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet |
||||
Examples: |
||||
.. code-block:: python |
||||
from paddle.vision.models import resnet50 |
||||
# build model |
||||
model = resnet50() |
||||
# build model and load imagenet pretrained weight |
||||
# model = resnet50(pretrained=True) |
||||
""" |
||||
return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs) |
||||
|
||||
|
||||
def resnet101(pretrained=False, **kwargs): |
||||
"""ResNet 101-layer model |
||||
|
||||
Args: |
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet |
||||
Examples: |
||||
.. code-block:: python |
||||
from paddle.vision.models import resnet101 |
||||
# build model |
||||
model = resnet101() |
||||
# build model and load imagenet pretrained weight |
||||
# model = resnet101(pretrained=True) |
||||
""" |
||||
return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs) |
||||
|
||||
|
||||
def resnet152(pretrained=False, **kwargs): |
||||
"""ResNet 152-layer model |
||||
|
||||
Args: |
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet |
||||
Examples: |
||||
.. code-block:: python |
||||
from paddle.vision.models import resnet152 |
||||
# build model |
||||
model = resnet152() |
||||
# build model and load imagenet pretrained weight |
||||
# model = resnet152(pretrained=True) |
||||
""" |
||||
return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs) |
@ -0,0 +1,395 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
import paddle.nn.functional as F |
||||
from paddle.nn.initializer import Normal |
||||
|
||||
|
||||
from .backbones import resnet |
||||
from .layers import Conv3x3, Conv1x1, get_norm_layer, Identity |
||||
from .param_init import KaimingInitMixin |
||||
|
||||
|
||||
class BIT(nn.Layer): |
||||
""" |
||||
The BIT implementation based on PaddlePaddle. |
||||
|
||||
The original article refers to |
||||
H. Chen, et al., "Remote Sensing Image Change Detection With Transformers" |
||||
(https://arxiv.org/abs/2103.00208). |
||||
|
||||
This implementation adopts pretrained encoders, as opposed to the original work where weights are randomly initialized. |
||||
|
||||
Args: |
||||
in_channels (int): The number of bands of the input images. |
||||
num_classes (int): The number of target classes. |
||||
backbone (str, optional): The ResNet architecture that is used as the backbone. Currently, only 'resnet18' and |
||||
'resnet34' are supported. Default: 'resnet18'. |
||||
n_stages (int, optional): The number of ResNet stages used in the backbone, which should be a value in {3,4,5}. |
||||
Default: 4. |
||||
use_tokenizer (bool, optional): Use a tokenizer or not. Default: True. |
||||
token_len (int, optional): The length of input tokens. Default: 4. |
||||
pool_mode (str, optional): The pooling strategy to obtain input tokens when `use_tokenizer` is set to False. 'max' |
||||
for global max pooling and 'avg' for global average pooling. Default: 'max'. |
||||
pool_size (int, optional): The height and width of the pooled feature maps when `use_tokenizer` is set to False. |
||||
Default: 2. |
||||
enc_with_pos (bool, optional): Whether to add leanred positional embedding to the input feature sequence of the |
||||
encoder. Default: True. |
||||
enc_depth (int, optional): The number of attention blocks used in the encoder. Default: 1 |
||||
enc_head_dim (int, optional): The embedding dimension of each encoder head. Default: 64. |
||||
dec_depth (int, optional): The number of attention blocks used in the decoder. Default: 8. |
||||
dec_head_dim (int, optional): The embedding dimension of each decoder head. Default: 8. |
||||
|
||||
Raises: |
||||
ValueError: When an unsupported backbone type is specified, or the number of backbone stages is not 3, 4, or 5. |
||||
""" |
||||
|
||||
def __init__( |
||||
self, in_channels, num_classes, |
||||
backbone='resnet18', n_stages=4, |
||||
use_tokenizer=True, token_len=4, |
||||
pool_mode='max', pool_size=2, |
||||
enc_with_pos=True, |
||||
enc_depth=1, enc_head_dim=64, |
||||
dec_depth=8, dec_head_dim=8, |
||||
**backbone_kwargs |
||||
): |
||||
super().__init__() |
||||
|
||||
# TODO: reduce hard-coded parameters |
||||
DIM = 32 |
||||
MLP_DIM = 2*DIM |
||||
EBD_DIM = DIM |
||||
|
||||
self.backbone = Backbone(in_channels, EBD_DIM, arch=backbone, n_stages=n_stages, **backbone_kwargs) |
||||
|
||||
self.use_tokenizer = use_tokenizer |
||||
if not use_tokenizer: |
||||
# If a tokenzier is not to be used,then downsample the feature maps. |
||||
self.pool_size = pool_size |
||||
self.pool_mode = pool_mode |
||||
self.token_len = pool_size * pool_size |
||||
else: |
||||
self.conv_att = Conv1x1(32, token_len, bias=False) |
||||
self.token_len = token_len |
||||
|
||||
self.enc_with_pos = enc_with_pos |
||||
if enc_with_pos: |
||||
self.enc_pos_embedding = self.create_parameter( |
||||
shape=(1,self.token_len*2,EBD_DIM), |
||||
default_initializer=Normal() |
||||
) |
||||
|
||||
self.enc_depth = enc_depth |
||||
self.dec_depth = dec_depth |
||||
self.enc_head_dim = enc_head_dim |
||||
self.dec_head_dim = dec_head_dim |
||||
|
||||
self.encoder = TransformerEncoder( |
||||
dim=DIM, |
||||
depth=enc_depth, |
||||
n_heads=8, |
||||
head_dim=enc_head_dim, |
||||
mlp_dim=MLP_DIM, |
||||
dropout_rate=0. |
||||
) |
||||
self.decoder = TransformerDecoder( |
||||
dim=DIM, |
||||
depth=dec_depth, |
||||
n_heads=8, |
||||
head_dim=dec_head_dim, |
||||
mlp_dim=MLP_DIM, |
||||
dropout_rate=0., |
||||
apply_softmax=True |
||||
) |
||||
|
||||
self.upsample = nn.Upsample(scale_factor=4, mode='bilinear') |
||||
self.conv_out = nn.Sequential( |
||||
Conv3x3(EBD_DIM, EBD_DIM, norm=True, act=True), |
||||
Conv3x3(EBD_DIM, num_classes) |
||||
) |
||||
|
||||
def _get_semantic_tokens(self, x): |
||||
b, c = paddle.shape(x)[:2] |
||||
att_map = self.conv_att(x) |
||||
att_map = att_map.reshape((b,self.token_len,1,-1)) |
||||
att_map = F.softmax(att_map, axis=-1) |
||||
x = x.reshape((b,1,c,-1)) |
||||
tokens = (x*att_map).sum(-1) |
||||
return tokens |
||||
|
||||
def _get_reshaped_tokens(self, x): |
||||
if self.pool_mode == 'max': |
||||
x = F.adaptive_max_pool2d(x, (self.pool_size, self.pool_size)) |
||||
elif self.pool_mode == 'avg': |
||||
x = F.adaptive_avg_pool2d(x, (self.pool_size, self.pool_size)) |
||||
else: |
||||
x = x |
||||
tokens = x.transpose((0,2,3,1)).flatten(1,2) |
||||
return tokens |
||||
|
||||
def encode(self, x): |
||||
if self.enc_with_pos: |
||||
x += self.enc_pos_embedding |
||||
x = self.encoder(x) |
||||
return x |
||||
|
||||
def decode(self, x, m): |
||||
b, c, h, w = paddle.shape(x) |
||||
x = x.transpose((0,2,3,1)).flatten(1,2) |
||||
x = self.decoder(x, m) |
||||
x = x.transpose((0,2,1)).reshape((b,c,h,w)) |
||||
return x |
||||
|
||||
def forward(self, t1, t2): |
||||
# Extract features via shared backbone. |
||||
x1 = self.backbone(t1) |
||||
x2 = self.backbone(t2) |
||||
|
||||
# Tokenization |
||||
if self.use_tokenizer: |
||||
token1 = self._get_semantic_tokens(x1) |
||||
token2 = self._get_semantic_tokens(x2) |
||||
else: |
||||
token1 = self._get_reshaped_tokens(x1) |
||||
token2 = self._get_reshaped_tokens(x2) |
||||
|
||||
# Transformer encoder forward |
||||
token = paddle.concat([token1, token2], axis=1) |
||||
token = self.encode(token) |
||||
token1, token2 = paddle.chunk(token, 2, axis=1) |
||||
|
||||
# Transformer decoder forward |
||||
y1 = self.decode(x1, token1) |
||||
y2 = self.decode(x2, token2) |
||||
|
||||
# Feature differencing |
||||
y = paddle.abs(y1 - y2) |
||||
y = self.upsample(y) |
||||
|
||||
# Classifier forward |
||||
pred = self.conv_out(y) |
||||
return pred, |
||||
|
||||
def init_weight(self): |
||||
# Use the default initialization method. |
||||
pass |
||||
|
||||
|
||||
class Residual(nn.Layer): |
||||
def __init__(self, fn): |
||||
super().__init__() |
||||
self.fn = fn |
||||
|
||||
def forward(self, x, **kwargs): |
||||
return self.fn(x, **kwargs) + x |
||||
|
||||
|
||||
class Residual2(nn.Layer): |
||||
def __init__(self, fn): |
||||
super().__init__() |
||||
self.fn = fn |
||||
|
||||
def forward(self, x1, x2, **kwargs): |
||||
return self.fn(x1, x2, **kwargs) + x1 |
||||
|
||||
|
||||
class PreNorm(nn.Layer): |
||||
def __init__(self, dim, fn): |
||||
super().__init__() |
||||
self.norm = nn.LayerNorm(dim) |
||||
self.fn = fn |
||||
|
||||
def forward(self, x, **kwargs): |
||||
return self.fn(self.norm(x), **kwargs) |
||||
|
||||
|
||||
class PreNorm2(nn.Layer): |
||||
def __init__(self, dim, fn): |
||||
super().__init__() |
||||
self.norm = nn.LayerNorm(dim) |
||||
self.fn = fn |
||||
|
||||
def forward(self, x1, x2, **kwargs): |
||||
return self.fn(self.norm(x1), self.norm(x2), **kwargs) |
||||
|
||||
|
||||
class FeedForward(nn.Sequential): |
||||
def __init__(self, dim, hidden_dim, dropout_rate=0.): |
||||
super().__init__( |
||||
nn.Linear(dim, hidden_dim), |
||||
nn.GELU(), |
||||
nn.Dropout(dropout_rate), |
||||
nn.Linear(hidden_dim, dim), |
||||
nn.Dropout(dropout_rate) |
||||
) |
||||
|
||||
|
||||
class CrossAttention(nn.Layer): |
||||
def __init__(self, dim, n_heads=8, head_dim=64, dropout_rate=0., apply_softmax=True): |
||||
super().__init__() |
||||
|
||||
inner_dim = head_dim * n_heads |
||||
self.n_heads = n_heads |
||||
self.scale = dim ** -0.5 |
||||
|
||||
self.apply_softmax = apply_softmax |
||||
|
||||
self.fc_q = nn.Linear(dim, inner_dim, bias_attr=False) |
||||
self.fc_k = nn.Linear(dim, inner_dim, bias_attr=False) |
||||
self.fc_v = nn.Linear(dim, inner_dim, bias_attr=False) |
||||
|
||||
self.fc_out = nn.Sequential( |
||||
nn.Linear(inner_dim, dim), |
||||
nn.Dropout(dropout_rate) |
||||
) |
||||
|
||||
def forward(self, x, ref): |
||||
b, n = paddle.shape(x)[:2] |
||||
h = self.n_heads |
||||
|
||||
q = self.fc_q(x) |
||||
k = self.fc_k(ref) |
||||
v = self.fc_v(ref) |
||||
|
||||
q = q.reshape((b,n,h,-1)).transpose((0,2,1,3)) |
||||
k = k.reshape((b,paddle.shape(ref)[1],h,-1)).transpose((0,2,1,3)) |
||||
v = v.reshape((b,paddle.shape(ref)[1],h,-1)).transpose((0,2,1,3)) |
||||
|
||||
mult = paddle.matmul(q, k, transpose_y=True) * self.scale |
||||
|
||||
if self.apply_softmax: |
||||
mult = F.softmax(mult, axis=-1) |
||||
|
||||
out = paddle.matmul(mult, v) |
||||
out = out.transpose((0,2,1,3)).flatten(2) |
||||
return self.fc_out(out) |
||||
|
||||
|
||||
class SelfAttention(CrossAttention): |
||||
def forward(self, x): |
||||
return super().forward(x, x) |
||||
|
||||
|
||||
class TransformerEncoder(nn.Layer): |
||||
def __init__(self, dim, depth, n_heads, head_dim, mlp_dim, dropout_rate): |
||||
super().__init__() |
||||
self.layers = nn.LayerList([]) |
||||
for _ in range(depth): |
||||
self.layers.append(nn.LayerList([ |
||||
Residual(PreNorm(dim, SelfAttention(dim, n_heads, head_dim, dropout_rate))), |
||||
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))) |
||||
])) |
||||
|
||||
def forward(self, x): |
||||
for att, ff in self.layers: |
||||
x = att(x) |
||||
x = ff(x) |
||||
return x |
||||
|
||||
|
||||
class TransformerDecoder(nn.Layer): |
||||
def __init__(self, dim, depth, n_heads, head_dim, mlp_dim, dropout_rate, apply_softmax=True): |
||||
super().__init__() |
||||
self.layers = nn.LayerList([]) |
||||
for _ in range(depth): |
||||
self.layers.append(nn.LayerList([ |
||||
Residual2(PreNorm2(dim, CrossAttention(dim, n_heads, head_dim, dropout_rate, apply_softmax))), |
||||
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))) |
||||
])) |
||||
|
||||
def forward(self, x, m): |
||||
for att, ff in self.layers: |
||||
x = att(x, m) |
||||
x = ff(x) |
||||
return x |
||||
|
||||
|
||||
class Backbone(nn.Layer, KaimingInitMixin): |
||||
def __init__( |
||||
self, |
||||
in_ch, out_ch=32, |
||||
arch='resnet18', |
||||
pretrained=True, |
||||
n_stages=5 |
||||
): |
||||
super().__init__() |
||||
|
||||
expand = 1 |
||||
strides = (2,1,2,1,1) |
||||
if arch == 'resnet18': |
||||
self.resnet = resnet.resnet18(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer()) |
||||
elif arch == 'resnet34': |
||||
self.resnet = resnet.resnet34(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer()) |
||||
else: |
||||
raise ValueError |
||||
|
||||
self.n_stages = n_stages |
||||
|
||||
if self.n_stages == 5: |
||||
itm_ch = 512 * expand |
||||
elif self.n_stages == 4: |
||||
itm_ch = 256 * expand |
||||
elif self.n_stages == 3: |
||||
itm_ch = 128 * expand |
||||
else: |
||||
raise ValueError |
||||
|
||||
self.upsample = nn.Upsample(scale_factor=2) |
||||
self.conv_out = Conv3x3(itm_ch, out_ch) |
||||
|
||||
self._trim_resnet() |
||||
|
||||
if in_ch != 3: |
||||
self.resnet.conv1 = nn.Conv2D( |
||||
in_ch, |
||||
64, |
||||
kernel_size=7, |
||||
stride=2, |
||||
padding=3, |
||||
bias_attr=False |
||||
) |
||||
|
||||
if not pretrained: |
||||
self.init_weight() |
||||
|
||||
def forward(self, x): |
||||
y = self.resnet.conv1(x) |
||||
y = self.resnet.bn1(y) |
||||
y = self.resnet.relu(y) |
||||
y = self.resnet.maxpool(y) |
||||
|
||||
y = self.resnet.layer1(y) |
||||
y = self.resnet.layer2(y) |
||||
y = self.resnet.layer3(y) |
||||
y = self.resnet.layer4(y) |
||||
|
||||
y = self.upsample(y) |
||||
|
||||
return self.conv_out(y) |
||||
|
||||
def _trim_resnet(self): |
||||
if self.n_stages > 5: |
||||
raise ValueError |
||||
|
||||
if self.n_stages < 5: |
||||
self.resnet.layer4 = Identity() |
||||
|
||||
if self.n_stages <= 3: |
||||
self.resnet.layer3 = Identity() |
||||
|
||||
self.resnet.avgpool = Identity() |
||||
self.resnet.fc = Identity() |
@ -0,0 +1,96 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
import paddle.nn.functional as F |
||||
|
||||
|
||||
from .layers import make_norm, Conv3x3, CBAM |
||||
from .stanet import Backbone, Decoder |
||||
|
||||
|
||||
class DSAMNet(nn.Layer): |
||||
""" |
||||
The DSAMNet implementation based on PaddlePaddle. |
||||
|
||||
The original article refers to |
||||
Q. Shi, et al., "A Deeply Supervised Attention Metric-Based Network and an Open Aerial Image Dataset for Remote Sensing |
||||
Change Detection" |
||||
(https://ieeexplore.ieee.org/document/9467555). |
||||
|
||||
Note that this implementation differs from the original work in two aspects: |
||||
1. We do not use multiple dilation rates in layer 4 of the ResNet backbone. |
||||
2. A classification head is used in place of the original metric learning-based head to stablize the training process. |
||||
|
||||
Args: |
||||
in_channels (int): The number of bands of the input images. |
||||
num_classes (int): The number of target classes. |
||||
ca_ratio (int, optional): The channel reduction ratio for the channel attention module. Default: 8. |
||||
sa_kernel (int, optional): The size of the convolutional kernel used in the spatial attention module. Default: 7. |
||||
""" |
||||
|
||||
def __init__(self, in_channels, num_classes, ca_ratio=8, sa_kernel=7): |
||||
super().__init__() |
||||
|
||||
WIDTH = 64 |
||||
|
||||
self.backbone = Backbone(in_ch=in_channels, arch='resnet18', strides=(1,1,2,2,1)) |
||||
self.decoder = Decoder(WIDTH) |
||||
|
||||
self.cbam1 = CBAM(64, ratio=ca_ratio, kernel_size=sa_kernel) |
||||
self.cbam2 = CBAM(64, ratio=ca_ratio, kernel_size=sa_kernel) |
||||
|
||||
self.dsl2 = DSLayer(64, num_classes, 32, stride=2, output_padding=1) |
||||
self.dsl3 = DSLayer(128, num_classes, 32, stride=4, output_padding=3) |
||||
|
||||
self.conv_out = nn.Sequential( |
||||
Conv3x3(WIDTH, WIDTH, norm=True, act=True), |
||||
Conv3x3(WIDTH, num_classes) |
||||
) |
||||
|
||||
self.init_weight() |
||||
|
||||
def forward(self, t1, t2): |
||||
f1 = self.backbone(t1) |
||||
f2 = self.backbone(t2) |
||||
|
||||
y1 = self.decoder(f1) |
||||
y2 = self.decoder(f2) |
||||
|
||||
y1 = self.cbam1(y1) |
||||
y2 = self.cbam2(y2) |
||||
|
||||
out = paddle.abs(y1-y2) |
||||
out = F.interpolate(out, size=paddle.shape(t1)[2:], mode='bilinear', align_corners=True) |
||||
pred = self.conv_out(out) |
||||
|
||||
ds2 = self.dsl2(paddle.abs(f1[0]-f2[0])) |
||||
ds3 = self.dsl3(paddle.abs(f1[1]-f2[1])) |
||||
|
||||
return pred, ds2, ds3 |
||||
|
||||
def init_weight(self): |
||||
pass |
||||
|
||||
|
||||
class DSLayer(nn.Sequential): |
||||
def __init__(self, in_ch, out_ch, itm_ch, **convd_kwargs): |
||||
super().__init__( |
||||
nn.Conv2DTranspose(in_ch, itm_ch, kernel_size=3, padding=1, **convd_kwargs), |
||||
make_norm(itm_ch), |
||||
nn.ReLU(), |
||||
nn.Dropout2D(p=0.2), |
||||
nn.Conv2DTranspose(itm_ch, out_ch, kernel_size=3, padding=1) |
||||
) |
@ -0,0 +1,209 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
import paddle.nn.functional as F |
||||
from paddle.vision.models import vgg16 |
||||
|
||||
|
||||
from .layers import Conv1x1, make_norm, ChannelAttention, SpatialAttention |
||||
|
||||
|
||||
class DSIFN(nn.Layer): |
||||
""" |
||||
The DSIFN implementation based on PaddlePaddle. |
||||
|
||||
The original article refers to |
||||
C. Zhang, et al., "A deeply supervised image fusion network for change detection in high resolution bi-temporal remote |
||||
sensing images" |
||||
(https://www.sciencedirect.com/science/article/pii/S0924271620301532). |
||||
|
||||
Note that in this implementation, there is a flexible number of target classes. |
||||
|
||||
Args: |
||||
num_classes (int): The number of target classes. |
||||
use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is trained |
||||
on a relatively small dataset, the dropout layers help prevent overfitting. Default: False. |
||||
""" |
||||
|
||||
def __init__(self, num_classes, use_dropout=False): |
||||
super().__init__() |
||||
|
||||
self.encoder1 = self.encoder2 = VGG16FeaturePicker() |
||||
|
||||
self.sa1 = SpatialAttention() |
||||
self.sa2= SpatialAttention() |
||||
self.sa3 = SpatialAttention() |
||||
self.sa4 = SpatialAttention() |
||||
self.sa5 = SpatialAttention() |
||||
|
||||
self.ca1 = ChannelAttention(in_ch=1024) |
||||
self.bn_ca1 = make_norm(1024) |
||||
self.o1_conv1 = conv2d_bn(1024, 512, use_dropout) |
||||
self.o1_conv2 = conv2d_bn(512, 512, use_dropout) |
||||
self.bn_sa1 = make_norm(512) |
||||
self.o1_conv3 = Conv1x1(512, num_classes) |
||||
self.trans_conv1 = nn.Conv2DTranspose(512, 512, kernel_size=2, stride=2) |
||||
|
||||
self.ca2 = ChannelAttention(in_ch=1536) |
||||
self.bn_ca2 = make_norm(1536) |
||||
self.o2_conv1 = conv2d_bn(1536, 512, use_dropout) |
||||
self.o2_conv2 = conv2d_bn(512, 256, use_dropout) |
||||
self.o2_conv3 = conv2d_bn(256, 256, use_dropout) |
||||
self.bn_sa2 = make_norm(256) |
||||
self.o2_conv4 = Conv1x1(256, num_classes) |
||||
self.trans_conv2 = nn.Conv2DTranspose(256, 256, kernel_size=2, stride=2) |
||||
|
||||
self.ca3 = ChannelAttention(in_ch=768) |
||||
self.o3_conv1 = conv2d_bn(768, 256, use_dropout) |
||||
self.o3_conv2 = conv2d_bn(256, 128, use_dropout) |
||||
self.o3_conv3 = conv2d_bn(128, 128, use_dropout) |
||||
self.bn_sa3 = make_norm(128) |
||||
self.o3_conv4 = Conv1x1(128, num_classes) |
||||
self.trans_conv3 = nn.Conv2DTranspose(128, 128, kernel_size=2, stride=2) |
||||
|
||||
self.ca4 = ChannelAttention(in_ch=384) |
||||
self.o4_conv1 = conv2d_bn(384, 128, use_dropout) |
||||
self.o4_conv2 = conv2d_bn(128, 64, use_dropout) |
||||
self.o4_conv3 = conv2d_bn(64, 64, use_dropout) |
||||
self.bn_sa4 = make_norm(64) |
||||
self.o4_conv4 = Conv1x1(64, num_classes) |
||||
self.trans_conv4 = nn.Conv2DTranspose(64, 64, kernel_size=2, stride=2) |
||||
|
||||
self.ca5 = ChannelAttention(in_ch=192) |
||||
self.o5_conv1 = conv2d_bn(192, 64, use_dropout) |
||||
self.o5_conv2 = conv2d_bn(64, 32, use_dropout) |
||||
self.o5_conv3 = conv2d_bn(32, 16, use_dropout) |
||||
self.bn_sa5 = make_norm(16) |
||||
self.o5_conv4 = Conv1x1(16, num_classes) |
||||
|
||||
self.init_weight() |
||||
|
||||
def forward(self, t1, t2): |
||||
# Extract bi-temporal features. |
||||
with paddle.no_grad(): |
||||
self.encoder1.eval(), self.encoder2.eval() |
||||
t1_feats = self.encoder1(t1) |
||||
t2_feats = self.encoder2(t2) |
||||
|
||||
t1_f_l3, t1_f_l8, t1_f_l15, t1_f_l22, t1_f_l29 = t1_feats |
||||
t2_f_l3, t2_f_l8, t2_f_l15, t2_f_l22, t2_f_l29,= t2_feats |
||||
|
||||
# Multi-level decoding |
||||
x = paddle.concat([t1_f_l29, t2_f_l29], axis=1) |
||||
x = self.o1_conv1(x) |
||||
x = self.o1_conv2(x) |
||||
x = self.sa1(x) * x |
||||
x = self.bn_sa1(x) |
||||
|
||||
out1 = F.interpolate( |
||||
self.o1_conv3(x), |
||||
size=paddle.shape(t1)[2:], |
||||
mode='bilinear', |
||||
align_corners=True |
||||
) |
||||
|
||||
x = self.trans_conv1(x) |
||||
x = paddle.concat([x, t1_f_l22, t2_f_l22], axis=1) |
||||
x = self.ca2(x)*x |
||||
x = self.o2_conv1(x) |
||||
x = self.o2_conv2(x) |
||||
x = self.o2_conv3(x) |
||||
x = self.sa2(x) *x |
||||
x = self.bn_sa2(x) |
||||
|
||||
out2 = F.interpolate( |
||||
self.o2_conv4(x), |
||||
size=paddle.shape(t1)[2:], |
||||
mode='bilinear', |
||||
align_corners=True |
||||
) |
||||
|
||||
x = self.trans_conv2(x) |
||||
x = paddle.concat([x, t1_f_l15, t2_f_l15], axis=1) |
||||
x = self.ca3(x)*x |
||||
x = self.o3_conv1(x) |
||||
x = self.o3_conv2(x) |
||||
x = self.o3_conv3(x) |
||||
x = self.sa3(x) *x |
||||
x = self.bn_sa3(x) |
||||
|
||||
out3 = F.interpolate( |
||||
self.o3_conv4(x), |
||||
size=paddle.shape(t1)[2:], |
||||
mode='bilinear', |
||||
align_corners=True |
||||
) |
||||
|
||||
x = self.trans_conv3(x) |
||||
x = paddle.concat([x, t1_f_l8, t2_f_l8], axis=1) |
||||
x = self.ca4(x)*x |
||||
x = self.o4_conv1(x) |
||||
x = self.o4_conv2(x) |
||||
x = self.o4_conv3(x) |
||||
x = self.sa4(x) *x |
||||
x = self.bn_sa4(x) |
||||
|
||||
out4 = F.interpolate( |
||||
self.o4_conv4(x), |
||||
size=paddle.shape(t1)[2:], |
||||
mode='bilinear', |
||||
align_corners=True |
||||
) |
||||
|
||||
x = self.trans_conv4(x) |
||||
x = paddle.concat([x, t1_f_l3, t2_f_l3], axis=1) |
||||
x = self.ca5(x)*x |
||||
x = self.o5_conv1(x) |
||||
x = self.o5_conv2(x) |
||||
x = self.o5_conv3(x) |
||||
x = self.sa5(x) *x |
||||
x = self.bn_sa5(x) |
||||
|
||||
out5 = self.o5_conv4(x) |
||||
|
||||
return out5, out4, out3, out2, out1 |
||||
|
||||
def init_weight(self): |
||||
# Do nothing |
||||
pass |
||||
|
||||
|
||||
class VGG16FeaturePicker(nn.Layer): |
||||
def __init__(self, indices=(3,8,15,22,29)): |
||||
super().__init__() |
||||
features = list(vgg16(pretrained=True).features)[:30] |
||||
self.features = nn.LayerList(features) |
||||
self.features.eval() |
||||
self.indices = set(indices) |
||||
|
||||
def forward(self, x): |
||||
picked_feats = [] |
||||
for idx, model in enumerate(self.features): |
||||
x = model(x) |
||||
if idx in self.indices: |
||||
picked_feats.append(x) |
||||
return picked_feats |
||||
|
||||
|
||||
def conv2d_bn(in_ch, out_ch, with_dropout=True): |
||||
lst = [ |
||||
nn.Conv2D(in_ch, out_ch, kernel_size=3, stride=1, padding=1), |
||||
nn.PReLU(), |
||||
make_norm(out_ch), |
||||
] |
||||
if with_dropout: |
||||
lst.append(nn.Dropout(p=0.6)) |
||||
return nn.Sequential(*lst) |
@ -0,0 +1,16 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
from .blocks import * |
||||
from .attention import ChannelAttention, SpatialAttention, CBAM |
@ -0,0 +1,96 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
import paddle.nn.functional as F |
||||
|
||||
|
||||
from .blocks import Conv1x1, BasicConv |
||||
|
||||
|
||||
class ChannelAttention(nn.Layer): |
||||
""" |
||||
The channel attention module implementation based on PaddlePaddle. |
||||
|
||||
The original article refers to |
||||
Sanghyun Woo, et al., "CBAM: Convolutional Block Attention Module" |
||||
(https://arxiv.org/abs/1807.06521). |
||||
|
||||
Args: |
||||
in_ch (int): The number of channels of the input features. |
||||
ratio (int, optional): The channel reduction ratio. Default: 8. |
||||
""" |
||||
|
||||
def __init__(self, in_ch, ratio=8): |
||||
super().__init__() |
||||
self.avg_pool = nn.AdaptiveAvgPool2D(1) |
||||
self.max_pool = nn.AdaptiveMaxPool2D(1) |
||||
self.fc1 = Conv1x1(in_ch, in_ch//ratio, bias=False, act=True) |
||||
self.fc2 = Conv1x1(in_ch//ratio, in_ch, bias=False) |
||||
|
||||
def forward(self,x): |
||||
avg_out = self.fc2(self.fc1(self.avg_pool(x))) |
||||
max_out = self.fc2(self.fc1(self.max_pool(x))) |
||||
out = avg_out + max_out |
||||
return F.sigmoid(out) |
||||
|
||||
|
||||
class SpatialAttention(nn.Layer): |
||||
""" |
||||
The spatial attention module implementation based on PaddlePaddle. |
||||
|
||||
The original article refers to |
||||
Sanghyun Woo, et al., "CBAM: Convolutional Block Attention Module" |
||||
(https://arxiv.org/abs/1807.06521). |
||||
|
||||
Args: |
||||
kernel_size (int, optional): The size of the convolutional kernel. Default: 7. |
||||
""" |
||||
|
||||
def __init__(self, kernel_size=7): |
||||
super().__init__() |
||||
self.conv = BasicConv(2, 1, kernel_size, bias=False) |
||||
|
||||
def forward(self, x): |
||||
avg_out = paddle.mean(x, axis=1, keepdim=True) |
||||
max_out = paddle.max(x, axis=1, keepdim=True) |
||||
x = paddle.concat([avg_out, max_out], axis=1) |
||||
x = self.conv(x) |
||||
return F.sigmoid(x) |
||||
|
||||
|
||||
class CBAM(nn.Layer): |
||||
""" |
||||
The CBAM implementation based on PaddlePaddle. |
||||
|
||||
The original article refers to |
||||
Sanghyun Woo, et al., "CBAM: Convolutional Block Attention Module" |
||||
(https://arxiv.org/abs/1807.06521). |
||||
|
||||
Args: |
||||
in_ch (int): The number of channels of the input features. |
||||
ratio (int, optional): The channel reduction ratio for the channel attention module. Default: 8. |
||||
kernel_size (int, optional): The size of the convolutional kernel used in the spatial attention module. Default: 7. |
||||
""" |
||||
|
||||
def __init__(self, in_ch, ratio=8, kernel_size=7): |
||||
super().__init__() |
||||
self.ca = ChannelAttention(in_ch, ratio=ratio) |
||||
self.sa = SpatialAttention(kernel_size=kernel_size) |
||||
|
||||
def forward(self, x): |
||||
y = self.ca(x) * x |
||||
y = self.sa(y) * y |
||||
return y |
@ -0,0 +1,142 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle.nn as nn |
||||
|
||||
|
||||
__all__ = [ |
||||
'BasicConv', 'Conv1x1', 'Conv3x3', 'Conv7x7', |
||||
'MaxPool2x2', 'MaxUnPool2x2', |
||||
'ConvTransposed3x3', |
||||
'Identity', |
||||
'get_norm_layer', 'get_act_layer', |
||||
'make_norm', 'make_act' |
||||
] |
||||
|
||||
|
||||
def get_norm_layer(): |
||||
# TODO: select appropriate norm layer. |
||||
return nn.BatchNorm2D |
||||
|
||||
|
||||
def get_act_layer(): |
||||
# TODO: select appropriate activation layer. |
||||
return nn.ReLU |
||||
|
||||
|
||||
def make_norm(*args, **kwargs): |
||||
norm_layer = get_norm_layer() |
||||
return norm_layer(*args, **kwargs) |
||||
|
||||
|
||||
def make_act(*args, **kwargs): |
||||
act_layer = get_act_layer() |
||||
return act_layer(*args, **kwargs) |
||||
|
||||
|
||||
class BasicConv(nn.Layer): |
||||
def __init__( |
||||
self, in_ch, out_ch, |
||||
kernel_size, pad_mode='constant', |
||||
bias='auto', norm=False, act=False, |
||||
**kwargs |
||||
): |
||||
super().__init__() |
||||
seq = [] |
||||
if kernel_size >= 2: |
||||
seq.append(nn.Pad2D(kernel_size//2, mode=pad_mode)) |
||||
seq.append( |
||||
nn.Conv2D( |
||||
in_ch, out_ch, kernel_size, |
||||
stride=1, padding=0, |
||||
bias_attr=(False if norm else None) if bias=='auto' else bias, |
||||
**kwargs |
||||
) |
||||
) |
||||
if norm: |
||||
if norm is True: |
||||
norm = make_norm(out_ch) |
||||
seq.append(norm) |
||||
if act: |
||||
if act is True: |
||||
act = make_act() |
||||
seq.append(act) |
||||
self.seq = nn.Sequential(*seq) |
||||
|
||||
def forward(self, x): |
||||
return self.seq(x) |
||||
|
||||
|
||||
class Conv1x1(BasicConv): |
||||
def __init__(self, in_ch, out_ch, pad_mode='constant', bias='auto', norm=False, act=False, **kwargs): |
||||
super().__init__(in_ch, out_ch, 1, pad_mode=pad_mode, bias=bias, norm=norm, act=act, **kwargs) |
||||
|
||||
|
||||
class Conv3x3(BasicConv): |
||||
def __init__(self, in_ch, out_ch, pad_mode='constant', bias='auto', norm=False, act=False, **kwargs): |
||||
super().__init__(in_ch, out_ch, 3, pad_mode=pad_mode, bias=bias, norm=norm, act=act, **kwargs) |
||||
|
||||
|
||||
class Conv7x7(BasicConv): |
||||
def __init__(self, in_ch, out_ch, pad_mode='constant', bias='auto', norm=False, act=False, **kwargs): |
||||
super().__init__(in_ch, out_ch, 7, pad_mode=pad_mode, bias=bias, norm=norm, act=act, **kwargs) |
||||
|
||||
|
||||
class MaxPool2x2(nn.MaxPool2D): |
||||
def __init__(self, **kwargs): |
||||
super().__init__(kernel_size=2, stride=(2,2), padding=(0,0), **kwargs) |
||||
|
||||
|
||||
class MaxUnPool2x2(nn.MaxUnPool2D): |
||||
def __init__(self, **kwargs): |
||||
super().__init__(kernel_size=2, stride=(2,2), padding=(0,0), **kwargs) |
||||
|
||||
|
||||
class ConvTransposed3x3(nn.Layer): |
||||
def __init__( |
||||
self, in_ch, out_ch, |
||||
bias='auto', norm=False, act=False, |
||||
**kwargs |
||||
): |
||||
super().__init__() |
||||
seq = [] |
||||
seq.append( |
||||
nn.Conv2DTranspose( |
||||
in_ch, out_ch, 3, |
||||
stride=2, padding=1, |
||||
bias_attr=(False if norm else None) if bias=='auto' else bias, |
||||
**kwargs |
||||
) |
||||
) |
||||
if norm: |
||||
if norm is True: |
||||
norm = make_norm(out_ch) |
||||
seq.append(norm) |
||||
if act: |
||||
if act is True: |
||||
act = make_act() |
||||
seq.append(act) |
||||
self.seq = nn.Sequential(*seq) |
||||
|
||||
def forward(self, x): |
||||
return self.seq(x) |
||||
|
||||
|
||||
class Identity(nn.Layer): |
||||
"""A placeholder identity operator that accepts exactly one argument.""" |
||||
def __init__(self, *args, **kwargs): |
||||
super().__init__() |
||||
|
||||
def forward(self, x): |
||||
return x |
@ -0,0 +1,86 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle.nn as nn |
||||
import paddle.nn.functional as F |
||||
|
||||
|
||||
def normal_init(param, *args, **kwargs): |
||||
""" |
||||
Initialize parameters with a normal distribution. |
||||
|
||||
Args: |
||||
param (Tensor): The tensor that needs to be initialized. |
||||
|
||||
Returns: |
||||
The initialized parameters. |
||||
""" |
||||
|
||||
return nn.initializer.Normal(*args, **kwargs)(param) |
||||
|
||||
|
||||
def kaiming_normal_init(param, *args, **kwargs): |
||||
""" |
||||
Initialize parameters with the Kaiming normal distribution. |
||||
|
||||
For more information about the Kaiming initialization method, please refer to |
||||
https://arxiv.org/abs/1502.01852 |
||||
|
||||
Args: |
||||
param (Tensor): The tensor that needs to be initialized. |
||||
|
||||
Returns: |
||||
The initialized parameters. |
||||
""" |
||||
|
||||
return nn.initializer.KaimingNormal(*args, **kwargs)(param) |
||||
|
||||
|
||||
def constant_init(param, *args, **kwargs): |
||||
""" |
||||
Initialize parameters with constants. |
||||
|
||||
Args: |
||||
param (Tensor): The tensor that needs to be initialized. |
||||
|
||||
Returns: |
||||
The initialized parameters. |
||||
""" |
||||
|
||||
return nn.initializer.Constant(*args, **kwargs)(param) |
||||
|
||||
|
||||
class KaimingInitMixin: |
||||
""" |
||||
A mix-in that provides the Kaiming initialization functionality. |
||||
|
||||
Examples: |
||||
|
||||
from paddlers.models.cd.models.param_init import KaimingInitMixin |
||||
|
||||
class CustomNet(nn.Layer, KaimingInitMixin): |
||||
def __init__(self, num_channels, num_classes): |
||||
super().__init__() |
||||
self.conv = nn.Conv2D(num_channels, num_classes, 3, 1, 0, bias_attr=False) |
||||
self.bn = nn.BatchNorm2D(num_classes) |
||||
self.init_weight() |
||||
""" |
||||
|
||||
def init_weight(self): |
||||
for layer in self.sublayers(): |
||||
if isinstance(layer, nn.Conv2D): |
||||
kaiming_normal_init(layer.weight) |
||||
elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)): |
||||
constant_init(layer.weight, value=1) |
||||
constant_init(layer.bias, value=0) |
@ -0,0 +1,155 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
import paddle.nn.functional as F |
||||
|
||||
|
||||
from .layers import Conv1x1, MaxPool2x2, make_norm, ChannelAttention |
||||
from .param_init import KaimingInitMixin |
||||
|
||||
|
||||
class SNUNet(nn.Layer, KaimingInitMixin): |
||||
""" |
||||
The SNUNet implementation based on PaddlePaddle. |
||||
|
||||
The original article refers to |
||||
S. Fang, et al., "SNUNet-CD: A Densely Connected Siamese Network for Change Detection of VHR Images" |
||||
(https://ieeexplore.ieee.org/document/9355573). |
||||
|
||||
Note that bilinear interpolation is adopted as the upsampling method, which is different from the paper. |
||||
|
||||
Args: |
||||
in_channels (int): The number of bands of the input images. |
||||
num_classes (int): The number of target classes. |
||||
width (int, optional): The output channels of the first convolutional layer. Default: 32. |
||||
""" |
||||
|
||||
def __init__(self, in_channels, num_classes, width=32): |
||||
super().__init__() |
||||
|
||||
filters = (width, width*2, width*4, width*8, width*16) |
||||
|
||||
self.conv0_0 = ConvBlockNested(in_channels, filters[0], filters[0]) |
||||
self.conv1_0 = ConvBlockNested(filters[0], filters[1], filters[1]) |
||||
self.conv2_0 = ConvBlockNested(filters[1], filters[2], filters[2]) |
||||
self.conv3_0 = ConvBlockNested(filters[2], filters[3], filters[3]) |
||||
self.conv4_0 = ConvBlockNested(filters[3], filters[4], filters[4]) |
||||
self.down1 = MaxPool2x2() |
||||
self.down2 = MaxPool2x2() |
||||
self.down3 = MaxPool2x2() |
||||
self.down4 = MaxPool2x2() |
||||
self.up1_0 = Up(filters[1]) |
||||
self.up2_0 = Up(filters[2]) |
||||
self.up3_0 = Up(filters[3]) |
||||
self.up4_0 = Up(filters[4]) |
||||
|
||||
self.conv0_1 = ConvBlockNested(filters[0]*2+filters[1], filters[0], filters[0]) |
||||
self.conv1_1 = ConvBlockNested(filters[1]*2+filters[2], filters[1], filters[1]) |
||||
self.conv2_1 = ConvBlockNested(filters[2]*2+filters[3], filters[2], filters[2]) |
||||
self.conv3_1 = ConvBlockNested(filters[3]*2+filters[4], filters[3], filters[3]) |
||||
self.up1_1 = Up(filters[1]) |
||||
self.up2_1 = Up(filters[2]) |
||||
self.up3_1 = Up(filters[3]) |
||||
|
||||
self.conv0_2 = ConvBlockNested(filters[0]*3+filters[1], filters[0], filters[0]) |
||||
self.conv1_2 = ConvBlockNested(filters[1]*3+filters[2], filters[1], filters[1]) |
||||
self.conv2_2 = ConvBlockNested(filters[2]*3+filters[3], filters[2], filters[2]) |
||||
self.up1_2 = Up(filters[1]) |
||||
self.up2_2 = Up(filters[2]) |
||||
|
||||
self.conv0_3 = ConvBlockNested(filters[0]*4+filters[1], filters[0], filters[0]) |
||||
self.conv1_3 = ConvBlockNested(filters[1]*4+filters[2], filters[1], filters[1]) |
||||
self.up1_3 = Up(filters[1]) |
||||
|
||||
self.conv0_4 = ConvBlockNested(filters[0]*5+filters[1], filters[0], filters[0]) |
||||
|
||||
self.ca_intra = ChannelAttention(filters[0], ratio=4) |
||||
self.ca_inter = ChannelAttention(filters[0]*4, ratio=16) |
||||
|
||||
self.conv_out = Conv1x1(filters[0]*4, num_classes) |
||||
|
||||
self.init_weight() |
||||
|
||||
def forward(self, t1, t2): |
||||
x0_0_t1 = self.conv0_0(t1) |
||||
x1_0_t1 = self.conv1_0(self.down1(x0_0_t1)) |
||||
x2_0_t1 = self.conv2_0(self.down2(x1_0_t1)) |
||||
x3_0_t1 = self.conv3_0(self.down3(x2_0_t1)) |
||||
|
||||
x0_0_t2 = self.conv0_0(t2) |
||||
x1_0_t2 = self.conv1_0(self.down1(x0_0_t2)) |
||||
x2_0_t2 = self.conv2_0(self.down2(x1_0_t2)) |
||||
x3_0_t2 = self.conv3_0(self.down3(x2_0_t2)) |
||||
x4_0_t2 = self.conv4_0(self.down4(x3_0_t2)) |
||||
|
||||
x0_1 = self.conv0_1(paddle.concat([x0_0_t1, x0_0_t2, self.up1_0(x1_0_t2)], 1)) |
||||
x1_1 = self.conv1_1(paddle.concat([x1_0_t1, x1_0_t2, self.up2_0(x2_0_t2)], 1)) |
||||
x0_2 = self.conv0_2(paddle.concat([x0_0_t1, x0_0_t2, x0_1, self.up1_1(x1_1)], 1)) |
||||
|
||||
x2_1 = self.conv2_1(paddle.concat([x2_0_t1, x2_0_t2, self.up3_0(x3_0_t2)], 1)) |
||||
x1_2 = self.conv1_2(paddle.concat([x1_0_t1, x1_0_t2, x1_1, self.up2_1(x2_1)], 1)) |
||||
x0_3 = self.conv0_3(paddle.concat([x0_0_t1, x0_0_t2, x0_1, x0_2, self.up1_2(x1_2)], 1)) |
||||
|
||||
x3_1 = self.conv3_1(paddle.concat([x3_0_t1, x3_0_t2, self.up4_0(x4_0_t2)], 1)) |
||||
x2_2 = self.conv2_2(paddle.concat([x2_0_t1, x2_0_t2, x2_1, self.up3_1(x3_1)], 1)) |
||||
x1_3 = self.conv1_3(paddle.concat([x1_0_t1, x1_0_t2, x1_1, x1_2, self.up2_2(x2_2)], 1)) |
||||
x0_4 = self.conv0_4(paddle.concat([x0_0_t1, x0_0_t2, x0_1, x0_2, x0_3, self.up1_3(x1_3)], 1)) |
||||
|
||||
out = paddle.concat([x0_1, x0_2, x0_3, x0_4], 1) |
||||
|
||||
intra = paddle.sum(paddle.stack([x0_1, x0_2, x0_3, x0_4]), axis=0) |
||||
m_intra = self.ca_intra(intra) |
||||
out = self.ca_inter(out) * (out + paddle.tile(m_intra, (1,4,1,1))) |
||||
|
||||
pred = self.conv_out(out) |
||||
return pred, |
||||
|
||||
|
||||
class ConvBlockNested(nn.Layer): |
||||
def __init__(self, in_ch, out_ch, mid_ch): |
||||
super().__init__() |
||||
self.act = nn.ReLU() |
||||
self.conv1 = nn.Conv2D(in_ch, mid_ch, kernel_size=3, padding=1) |
||||
self.bn1 = make_norm(mid_ch) |
||||
self.conv2 = nn.Conv2D(mid_ch, out_ch, kernel_size=3, padding=1) |
||||
self.bn2 = make_norm(out_ch) |
||||
|
||||
def forward(self, x): |
||||
x = self.conv1(x) |
||||
identity = x |
||||
x = self.bn1(x) |
||||
x = self.act(x) |
||||
|
||||
x = self.conv2(x) |
||||
x = self.bn2(x) |
||||
output = self.act(x + identity) |
||||
return output |
||||
|
||||
|
||||
class Up(nn.Layer): |
||||
def __init__(self, in_ch, use_conv=False): |
||||
super().__init__() |
||||
if use_conv: |
||||
self.up = nn.Conv2DTranspose(in_ch, in_ch, 2, stride=2) |
||||
else: |
||||
self.up = nn.Upsample(scale_factor=2, |
||||
mode='bilinear', |
||||
align_corners=True) |
||||
|
||||
def forward(self, x): |
||||
x = self.up(x) |
||||
return x |
@ -0,0 +1,298 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
import paddle.nn.functional as F |
||||
|
||||
|
||||
from .backbones import resnet |
||||
from .layers import Conv1x1, Conv3x3, get_norm_layer, Identity |
||||
from .param_init import KaimingInitMixin |
||||
|
||||
|
||||
class STANet(nn.Layer): |
||||
""" |
||||
The STANet implementation based on PaddlePaddle. |
||||
|
||||
The original article refers to |
||||
H. Chen and Z. Shi, "A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection" |
||||
(https://www.mdpi.com/2072-4292/12/10/1662). |
||||
|
||||
Note that this implementation differs from the original work in two aspects: |
||||
1. We do not use multiple dilation rates in layer 4 of the ResNet backbone. |
||||
2. A classification head is used in place of the original metric learning-based head to stablize the training process. |
||||
|
||||
Args: |
||||
in_channels (int): The number of bands of the input images. |
||||
num_classes (int): The number of target classes. |
||||
att_type (str, optional): The attention module used in the model. Options are 'PAM' and 'BAM'. Default: 'BAM'. |
||||
ds_factor (int, optional): The downsampling factor of the attention modules. When `ds_factor` is set to values |
||||
greater than 1, the input features will first be processed by an average pooling layer with the kernel size of |
||||
`ds_factor`, before being used to calculate the attention scores. Default: 1. |
||||
|
||||
Raises: |
||||
ValueError: When `att_type` has an illeagal value (unsupported attention type). |
||||
""" |
||||
|
||||
def __init__( |
||||
self, |
||||
in_channels, |
||||
num_classes, |
||||
att_type='BAM', |
||||
ds_factor=1 |
||||
): |
||||
super().__init__() |
||||
|
||||
WIDTH = 64 |
||||
|
||||
self.extract = build_feat_extractor(in_ch=in_channels, width=WIDTH) |
||||
self.attend = build_sta_module(in_ch=WIDTH, att_type=att_type, ds=ds_factor) |
||||
self.conv_out = nn.Sequential( |
||||
Conv3x3(WIDTH, WIDTH, norm=True, act=True), |
||||
Conv3x3(WIDTH, num_classes) |
||||
) |
||||
|
||||
self.init_weight() |
||||
|
||||
def forward(self, t1, t2): |
||||
f1 = self.extract(t1) |
||||
f2 = self.extract(t2) |
||||
|
||||
f1, f2 = self.attend(f1, f2) |
||||
|
||||
y = paddle.abs(f1- f2) |
||||
y = F.interpolate(y, size=paddle.shape(t1)[2:], mode='bilinear', align_corners=True) |
||||
|
||||
pred = self.conv_out(y) |
||||
return pred, |
||||
|
||||
def init_weight(self): |
||||
# Do nothing here as the encoder and decoder weights have already been initialized. |
||||
# Note however that currently self.attend and self.conv_out use the default initilization method. |
||||
pass |
||||
|
||||
|
||||
def build_feat_extractor(in_ch, width): |
||||
return nn.Sequential( |
||||
Backbone(in_ch, 'resnet18'), |
||||
Decoder(width) |
||||
) |
||||
|
||||
|
||||
def build_sta_module(in_ch, att_type, ds): |
||||
if att_type == 'BAM': |
||||
return Attention(BAM(in_ch, ds)) |
||||
elif att_type == 'PAM': |
||||
return Attention(PAM(in_ch, ds)) |
||||
else: |
||||
raise ValueError |
||||
|
||||
|
||||
class Backbone(nn.Layer, KaimingInitMixin): |
||||
def __init__(self, in_ch, arch, pretrained=True, strides=(2,1,2,2,2)): |
||||
super().__init__() |
||||
|
||||
if arch == 'resnet18': |
||||
self.resnet = resnet.resnet18(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer()) |
||||
elif arch == 'resnet34': |
||||
self.resnet = resnet.resnet34(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer()) |
||||
elif arch == 'resnet50': |
||||
self.resnet = resnet.resnet50(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer()) |
||||
else: |
||||
raise ValueError |
||||
|
||||
self._trim_resnet() |
||||
|
||||
if in_ch != 3: |
||||
self.resnet.conv1 = nn.Conv2D( |
||||
in_ch, |
||||
64, |
||||
kernel_size=7, |
||||
stride=strides[0], |
||||
padding=3, |
||||
bias_attr=False |
||||
) |
||||
|
||||
if not pretrained: |
||||
self.init_weight() |
||||
|
||||
def forward(self, x): |
||||
x = self.resnet.conv1(x) |
||||
x = self.resnet.bn1(x) |
||||
x = self.resnet.relu(x) |
||||
x = self.resnet.maxpool(x) |
||||
|
||||
x1 = self.resnet.layer1(x) |
||||
x2 = self.resnet.layer2(x1) |
||||
x3 = self.resnet.layer3(x2) |
||||
x4 = self.resnet.layer4(x3) |
||||
|
||||
return x1, x2, x3, x4 |
||||
|
||||
def _trim_resnet(self): |
||||
self.resnet.avgpool = Identity() |
||||
self.resnet.fc = Identity() |
||||
|
||||
|
||||
class Decoder(nn.Layer, KaimingInitMixin): |
||||
def __init__(self, f_ch): |
||||
super().__init__() |
||||
self.dr1 = Conv1x1(64, 96, norm=True, act=True) |
||||
self.dr2 = Conv1x1(128, 96, norm=True, act=True) |
||||
self.dr3 = Conv1x1(256, 96, norm=True, act=True) |
||||
self.dr4 = Conv1x1(512, 96, norm=True, act=True) |
||||
self.conv_out = nn.Sequential( |
||||
Conv3x3(384, 256, norm=True, act=True), |
||||
nn.Dropout(0.5), |
||||
Conv1x1(256, f_ch, norm=True, act=True) |
||||
) |
||||
|
||||
self.init_weight() |
||||
|
||||
def forward(self, feats): |
||||
f1 = self.dr1(feats[0]) |
||||
f2 = self.dr2(feats[1]) |
||||
f3 = self.dr3(feats[2]) |
||||
f4 = self.dr4(feats[3]) |
||||
|
||||
f2 = F.interpolate(f2, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True) |
||||
f3 = F.interpolate(f3, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True) |
||||
f4 = F.interpolate(f4, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True) |
||||
|
||||
x = paddle.concat([f1, f2, f3, f4], axis=1) |
||||
y = self.conv_out(x) |
||||
|
||||
return y |
||||
|
||||
|
||||
class BAM(nn.Layer): |
||||
def __init__(self, in_ch, ds): |
||||
super().__init__() |
||||
|
||||
self.ds = ds |
||||
self.pool = nn.AvgPool2D(self.ds) |
||||
|
||||
self.val_ch = in_ch |
||||
self.key_ch = in_ch // 8 |
||||
self.conv_q = Conv1x1(in_ch, self.key_ch) |
||||
self.conv_k = Conv1x1(in_ch, self.key_ch) |
||||
self.conv_v = Conv1x1(in_ch, self.val_ch) |
||||
|
||||
self.softmax = nn.Softmax(axis=-1) |
||||
|
||||
def forward(self, x): |
||||
x = x.flatten(-2) |
||||
x_rs = self.pool(x) |
||||
|
||||
b, c, h, w = paddle.shape(x_rs) |
||||
query = self.conv_q(x_rs).reshape((b,-1,h*w)).transpose((0,2,1)) |
||||
key = self.conv_k(x_rs).reshape((b,-1,h*w)) |
||||
energy = paddle.bmm(query, key) |
||||
energy = (self.key_ch**(-0.5)) * energy |
||||
|
||||
attention = self.softmax(energy) |
||||
|
||||
value = self.conv_v(x_rs).reshape((b,-1,w*h)) |
||||
|
||||
out = paddle.bmm(value, attention.transpose((0,2,1))) |
||||
out = out.reshape((b,c,h,w)) |
||||
|
||||
out = F.interpolate(out, scale_factor=self.ds) |
||||
out = out + x |
||||
return out.reshape(out.shape[:-1]+[out.shape[-1]//2, 2]) |
||||
|
||||
|
||||
class PAMBlock(nn.Layer): |
||||
def __init__(self, in_ch, scale=1, ds=1): |
||||
super().__init__() |
||||
|
||||
self.scale = scale |
||||
self.ds = ds |
||||
self.pool = nn.AvgPool2D(self.ds) |
||||
|
||||
self.val_ch = in_ch |
||||
self.key_ch = in_ch // 8 |
||||
self.conv_q = Conv1x1(in_ch, self.key_ch, norm=True) |
||||
self.conv_k = Conv1x1(in_ch, self.key_ch, norm=True) |
||||
self.conv_v = Conv1x1(in_ch, self.val_ch) |
||||
|
||||
def forward(self, x): |
||||
x_rs = self.pool(x) |
||||
|
||||
# Get query, key, and value. |
||||
query = self.conv_q(x_rs) |
||||
key = self.conv_k(x_rs) |
||||
value = self.conv_v(x_rs) |
||||
|
||||
# Split the whole image into subregions. |
||||
b, c, h, w = paddle.shape(x_rs) |
||||
query = self._split_subregions(query) |
||||
key = self._split_subregions(key) |
||||
value = self._split_subregions(value) |
||||
|
||||
# Perform subregion-wise attention. |
||||
out = self._attend(query, key, value) |
||||
|
||||
# Stack subregions to reconstruct the whole image. |
||||
out = self._recons_whole(out, b, c, h, w) |
||||
out = F.interpolate(out, scale_factor=self.ds) |
||||
return out |
||||
|
||||
def _attend(self, query, key, value): |
||||
energy = paddle.bmm(query.transpose((0,2,1)), key) # batch matrix multiplication |
||||
energy = (self.key_ch**(-0.5)) * energy |
||||
attention = F.softmax(energy, axis=-1) |
||||
out = paddle.bmm(value, attention.transpose((0,2,1))) |
||||
return out |
||||
|
||||
def _split_subregions(self, x): |
||||
b, c, h, w = paddle.shape(x) |
||||
assert h % self.scale == 0 and w % self.scale == 0 |
||||
x = x.reshape((b, c, self.scale, h//self.scale, self.scale, w//self.scale)) |
||||
x = x.transpose((0,2,4,1,3,5)).reshape((b*self.scale*self.scale, c, -1)) |
||||
return x |
||||
|
||||
def _recons_whole(self, x, b, c, h, w): |
||||
x = x.reshape((b, self.scale, self.scale, c, h//self.scale, w//self.scale)) |
||||
x = x.transpose((0,3,1,4,2,5)).reshape((b, c, h, w)) |
||||
return x |
||||
|
||||
|
||||
class PAM(nn.Layer): |
||||
def __init__(self, in_ch, ds, scales=(1,2,4,8)): |
||||
super().__init__() |
||||
|
||||
self.stages = nn.LayerList([ |
||||
PAMBlock(in_ch, scale=s, ds=ds) |
||||
for s in scales |
||||
]) |
||||
self.conv_out = Conv1x1(in_ch*len(scales), in_ch, bias=False) |
||||
|
||||
def forward(self, x): |
||||
x = x.flatten(-2) |
||||
res = [stage(x) for stage in self.stages] |
||||
out = self.conv_out(paddle.concat(res, axis=1)) |
||||
return out.reshape(out.shape[:-1]+[out.shape[-1]//2, 2]) |
||||
|
||||
|
||||
class Attention(nn.Layer): |
||||
def __init__(self, att): |
||||
super().__init__() |
||||
self.att = att |
||||
|
||||
def forward(self, x1, x2): |
||||
x = paddle.stack([x1, x2], axis=-1) |
||||
y = self.att(x) |
||||
return y[...,0], y[...,1] |
@ -0,0 +1,201 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
import paddle.nn.functional as F |
||||
|
||||
|
||||
from .layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity |
||||
from .param_init import normal_init, constant_init |
||||
|
||||
|
||||
class UNetEarlyFusion(nn.Layer): |
||||
""" |
||||
The FC-EF implementation based on PaddlePaddle. |
||||
|
||||
The original article refers to |
||||
Caye Daudt, R., et al. "Fully convolutional siamese networks for change detection" |
||||
(https://arxiv.org/abs/1810.08462). |
||||
|
||||
Args: |
||||
in_channels (int): The number of bands of the input images. |
||||
num_classes (int): The number of target classes. |
||||
use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is trained |
||||
on a relatively small dataset, the dropout layers help prevent overfitting. Default: False. |
||||
""" |
||||
|
||||
def __init__( |
||||
self, |
||||
in_channels, |
||||
num_classes, |
||||
use_dropout=False |
||||
): |
||||
super().__init__() |
||||
|
||||
C1, C2, C3, C4, C5 = 16, 32, 64, 128, 256 |
||||
|
||||
self.use_dropout = use_dropout |
||||
|
||||
self.conv11 = Conv3x3(in_channels, C1, norm=True, act=True) |
||||
self.do11 = self._make_dropout() |
||||
self.conv12 = Conv3x3(C1, C1, norm=True, act=True) |
||||
self.do12 = self._make_dropout() |
||||
self.pool1 = MaxPool2x2() |
||||
|
||||
self.conv21 = Conv3x3(C1, C2, norm=True, act=True) |
||||
self.do21 = self._make_dropout() |
||||
self.conv22 = Conv3x3(C2, C2, norm=True, act=True) |
||||
self.do22 = self._make_dropout() |
||||
self.pool2 = MaxPool2x2() |
||||
|
||||
self.conv31 = Conv3x3(C2, C3, norm=True, act=True) |
||||
self.do31 = self._make_dropout() |
||||
self.conv32 = Conv3x3(C3, C3, norm=True, act=True) |
||||
self.do32 = self._make_dropout() |
||||
self.conv33 = Conv3x3(C3, C3, norm=True, act=True) |
||||
self.do33 = self._make_dropout() |
||||
self.pool3 = MaxPool2x2() |
||||
|
||||
self.conv41 = Conv3x3(C3, C4, norm=True, act=True) |
||||
self.do41 = self._make_dropout() |
||||
self.conv42 = Conv3x3(C4, C4, norm=True, act=True) |
||||
self.do42 = self._make_dropout() |
||||
self.conv43 = Conv3x3(C4, C4, norm=True, act=True) |
||||
self.do43 = self._make_dropout() |
||||
self.pool4 = MaxPool2x2() |
||||
|
||||
self.upconv4 = ConvTransposed3x3(C4, C4, output_padding=1) |
||||
|
||||
self.conv43d = Conv3x3(C5, C4, norm=True, act=True) |
||||
self.do43d = self._make_dropout() |
||||
self.conv42d = Conv3x3(C4, C4, norm=True, act=True) |
||||
self.do42d = self._make_dropout() |
||||
self.conv41d = Conv3x3(C4, C3, norm=True, act=True) |
||||
self.do41d = self._make_dropout() |
||||
|
||||
self.upconv3 = ConvTransposed3x3(C3, C3, output_padding=1) |
||||
|
||||
self.conv33d = Conv3x3(C4, C3, norm=True, act=True) |
||||
self.do33d = self._make_dropout() |
||||
self.conv32d = Conv3x3(C3, C3, norm=True, act=True) |
||||
self.do32d = self._make_dropout() |
||||
self.conv31d = Conv3x3(C3, C2, norm=True, act=True) |
||||
self.do31d = self._make_dropout() |
||||
|
||||
self.upconv2 = ConvTransposed3x3(C2, C2, output_padding=1) |
||||
|
||||
self.conv22d = Conv3x3(C3, C2, norm=True, act=True) |
||||
self.do22d = self._make_dropout() |
||||
self.conv21d = Conv3x3(C2, C1, norm=True, act=True) |
||||
self.do21d = self._make_dropout() |
||||
|
||||
self.upconv1 = ConvTransposed3x3(C1, C1, output_padding=1) |
||||
|
||||
self.conv12d = Conv3x3(C2, C1, norm=True, act=True) |
||||
self.do12d = self._make_dropout() |
||||
self.conv11d = Conv3x3(C1, num_classes) |
||||
|
||||
self.init_weight() |
||||
|
||||
def forward(self, t1, t2): |
||||
x = paddle.concat([t1, t2], axis=1) |
||||
|
||||
# Stage 1 |
||||
x11 = self.do11(self.conv11(x)) |
||||
x12 = self.do12(self.conv12(x11)) |
||||
x1p = self.pool1(x12) |
||||
|
||||
# Stage 2 |
||||
x21 = self.do21(self.conv21(x1p)) |
||||
x22 = self.do22(self.conv22(x21)) |
||||
x2p = self.pool2(x22) |
||||
|
||||
# Stage 3 |
||||
x31 = self.do31(self.conv31(x2p)) |
||||
x32 = self.do32(self.conv32(x31)) |
||||
x33 = self.do33(self.conv33(x32)) |
||||
x3p = self.pool3(x33) |
||||
|
||||
# Stage 4 |
||||
x41 = self.do41(self.conv41(x3p)) |
||||
x42 = self.do42(self.conv42(x41)) |
||||
x43 = self.do43(self.conv43(x42)) |
||||
x4p = self.pool4(x43) |
||||
|
||||
# Stage 4d |
||||
x4d = self.upconv4(x4p) |
||||
pad4 = ( |
||||
0, |
||||
paddle.shape(x43)[3]-paddle.shape(x4d)[3], |
||||
0, |
||||
paddle.shape(x43)[2]-paddle.shape(x4d)[2] |
||||
) |
||||
x4d = paddle.concat([F.pad(x4d, pad=pad4, mode='replicate'), x43], 1) |
||||
x43d = self.do43d(self.conv43d(x4d)) |
||||
x42d = self.do42d(self.conv42d(x43d)) |
||||
x41d = self.do41d(self.conv41d(x42d)) |
||||
|
||||
# Stage 3d |
||||
x3d = self.upconv3(x41d) |
||||
pad3 = ( |
||||
0, |
||||
paddle.shape(x33)[3]-paddle.shape(x3d)[3], |
||||
0, |
||||
paddle.shape(x33)[2]-paddle.shape(x3d)[2] |
||||
) |
||||
x3d = paddle.concat([F.pad(x3d, pad=pad3, mode='replicate'), x33], 1) |
||||
x33d = self.do33d(self.conv33d(x3d)) |
||||
x32d = self.do32d(self.conv32d(x33d)) |
||||
x31d = self.do31d(self.conv31d(x32d)) |
||||
|
||||
# Stage 2d |
||||
x2d = self.upconv2(x31d) |
||||
pad2 = ( |
||||
0, |
||||
paddle.shape(x22)[3]-paddle.shape(x2d)[3], |
||||
0, |
||||
paddle.shape(x22)[2]-paddle.shape(x2d)[2] |
||||
) |
||||
x2d = paddle.concat([F.pad(x2d, pad=pad2, mode='replicate'), x22], 1) |
||||
x22d = self.do22d(self.conv22d(x2d)) |
||||
x21d = self.do21d(self.conv21d(x22d)) |
||||
|
||||
# Stage 1d |
||||
x1d = self.upconv1(x21d) |
||||
pad1 = ( |
||||
0, |
||||
paddle.shape(x12)[3]-paddle.shape(x1d)[3], |
||||
0, |
||||
paddle.shape(x12)[2]-paddle.shape(x1d)[2] |
||||
) |
||||
x1d = paddle.concat([F.pad(x1d, pad=pad1, mode='replicate'), x12], 1) |
||||
x12d = self.do12d(self.conv12d(x1d)) |
||||
x11d = self.conv11d(x12d) |
||||
|
||||
return x11d, |
||||
|
||||
def init_weight(self): |
||||
for sublayer in self.sublayers(): |
||||
if isinstance(sublayer, nn.Conv2D): |
||||
normal_init(sublayer.weight, std=0.001) |
||||
elif isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)): |
||||
constant_init(sublayer.weight, value=1.0) |
||||
constant_init(sublayer.bias, value=0.0) |
||||
|
||||
def _make_dropout(self): |
||||
if self.use_dropout: |
||||
return nn.Dropout2D(p=0.2) |
||||
else: |
||||
return Identity() |
@ -0,0 +1,224 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
import paddle.nn.functional as F |
||||
|
||||
|
||||
from .layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity |
||||
from .param_init import normal_init, constant_init |
||||
|
||||
|
||||
class UNetSiamConc(nn.Layer): |
||||
""" |
||||
The FC-Siam-conc implementation based on PaddlePaddle. |
||||
|
||||
The original article refers to |
||||
Caye Daudt, R., et al. "Fully convolutional siamese networks for change detection" |
||||
(https://arxiv.org/abs/1810.08462). |
||||
|
||||
Args: |
||||
in_channels (int): The number of bands of the input images. |
||||
num_classes (int): The number of target classes. |
||||
use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is trained |
||||
on a relatively small dataset, the dropout layers help prevent overfitting. Default: False. |
||||
""" |
||||
|
||||
def __init__( |
||||
self, |
||||
in_channels, |
||||
num_classes, |
||||
use_dropout=False |
||||
): |
||||
super().__init__() |
||||
|
||||
C1, C2, C3, C4, C5 = 16, 32, 64, 128, 256 |
||||
|
||||
self.use_dropout = use_dropout |
||||
|
||||
self.conv11 = Conv3x3(in_channels, C1, norm=True, act=True) |
||||
self.do11 = self._make_dropout() |
||||
self.conv12 = Conv3x3(C1, C1, norm=True, act=True) |
||||
self.do12 = self._make_dropout() |
||||
self.pool1 = MaxPool2x2() |
||||
|
||||
self.conv21 = Conv3x3(C1, C2, norm=True, act=True) |
||||
self.do21 = self._make_dropout() |
||||
self.conv22 = Conv3x3(C2, C2, norm=True, act=True) |
||||
self.do22 = self._make_dropout() |
||||
self.pool2 = MaxPool2x2() |
||||
|
||||
self.conv31 = Conv3x3(C2, C3, norm=True, act=True) |
||||
self.do31 = self._make_dropout() |
||||
self.conv32 = Conv3x3(C3, C3, norm=True, act=True) |
||||
self.do32 = self._make_dropout() |
||||
self.conv33 = Conv3x3(C3, C3, norm=True, act=True) |
||||
self.do33 = self._make_dropout() |
||||
self.pool3 = MaxPool2x2() |
||||
|
||||
self.conv41 = Conv3x3(C3, C4, norm=True, act=True) |
||||
self.do41 = self._make_dropout() |
||||
self.conv42 = Conv3x3(C4, C4, norm=True, act=True) |
||||
self.do42 = self._make_dropout() |
||||
self.conv43 = Conv3x3(C4, C4, norm=True, act=True) |
||||
self.do43 = self._make_dropout() |
||||
self.pool4 = MaxPool2x2() |
||||
|
||||
self.upconv4 = ConvTransposed3x3(C4, C4, output_padding=1) |
||||
|
||||
self.conv43d = Conv3x3(C5+C4, C4, norm=True, act=True) |
||||
self.do43d = self._make_dropout() |
||||
self.conv42d = Conv3x3(C4, C4, norm=True, act=True) |
||||
self.do42d = self._make_dropout() |
||||
self.conv41d = Conv3x3(C4, C3, norm=True, act=True) |
||||
self.do41d = self._make_dropout() |
||||
|
||||
self.upconv3 = ConvTransposed3x3(C3, C3, output_padding=1) |
||||
|
||||
self.conv33d = Conv3x3(C4+C3, C3, norm=True, act=True) |
||||
self.do33d = self._make_dropout() |
||||
self.conv32d = Conv3x3(C3, C3, norm=True, act=True) |
||||
self.do32d = self._make_dropout() |
||||
self.conv31d = Conv3x3(C3, C2, norm=True, act=True) |
||||
self.do31d = self._make_dropout() |
||||
|
||||
self.upconv2 = ConvTransposed3x3(C2, C2, output_padding=1) |
||||
|
||||
self.conv22d = Conv3x3(C3+C2, C2, norm=True, act=True) |
||||
self.do22d = self._make_dropout() |
||||
self.conv21d = Conv3x3(C2, C1, norm=True, act=True) |
||||
self.do21d = self._make_dropout() |
||||
|
||||
self.upconv1 = ConvTransposed3x3(C1, C1, output_padding=1) |
||||
|
||||
self.conv12d = Conv3x3(C2+C1, C1, norm=True, act=True) |
||||
self.do12d = self._make_dropout() |
||||
self.conv11d = Conv3x3(C1, num_classes) |
||||
|
||||
self.init_weight() |
||||
|
||||
def forward(self, t1, t2): |
||||
# Encode t1 |
||||
# Stage 1 |
||||
x11 = self.do11(self.conv11(t1)) |
||||
x12_1 = self.do12(self.conv12(x11)) |
||||
x1p = self.pool1(x12_1) |
||||
|
||||
# Stage 2 |
||||
x21 = self.do21(self.conv21(x1p)) |
||||
x22_1 = self.do22(self.conv22(x21)) |
||||
x2p = self.pool2(x22_1) |
||||
|
||||
# Stage 3 |
||||
x31 = self.do31(self.conv31(x2p)) |
||||
x32 = self.do32(self.conv32(x31)) |
||||
x33_1 = self.do33(self.conv33(x32)) |
||||
x3p = self.pool3(x33_1) |
||||
|
||||
# Stage 4 |
||||
x41 = self.do41(self.conv41(x3p)) |
||||
x42 = self.do42(self.conv42(x41)) |
||||
x43_1 = self.do43(self.conv43(x42)) |
||||
x4p = self.pool4(x43_1) |
||||
|
||||
# Encode t2 |
||||
# Stage 1 |
||||
x11 = self.do11(self.conv11(t2)) |
||||
x12_2 = self.do12(self.conv12(x11)) |
||||
x1p = self.pool1(x12_2) |
||||
|
||||
# Stage 2 |
||||
x21 = self.do21(self.conv21(x1p)) |
||||
x22_2 = self.do22(self.conv22(x21)) |
||||
x2p = self.pool2(x22_2) |
||||
|
||||
# Stage 3 |
||||
x31 = self.do31(self.conv31(x2p)) |
||||
x32 = self.do32(self.conv32(x31)) |
||||
x33_2 = self.do33(self.conv33(x32)) |
||||
x3p = self.pool3(x33_2) |
||||
|
||||
# Stage 4 |
||||
x41 = self.do41(self.conv41(x3p)) |
||||
x42 = self.do42(self.conv42(x41)) |
||||
x43_2 = self.do43(self.conv43(x42)) |
||||
x4p = self.pool4(x43_2) |
||||
|
||||
# Decode |
||||
# Stage 4d |
||||
x4d = self.upconv4(x4p) |
||||
pad4 = ( |
||||
0, |
||||
paddle.shape(x43_1)[3]-paddle.shape(x4d)[3], |
||||
0, |
||||
paddle.shape(x43_1)[2]-paddle.shape(x4d)[2] |
||||
) |
||||
x4d = paddle.concat([F.pad(x4d, pad=pad4, mode='replicate'), x43_1, x43_2], 1) |
||||
x43d = self.do43d(self.conv43d(x4d)) |
||||
x42d = self.do42d(self.conv42d(x43d)) |
||||
x41d = self.do41d(self.conv41d(x42d)) |
||||
|
||||
# Stage 3d |
||||
x3d = self.upconv3(x41d) |
||||
pad3 = ( |
||||
0, |
||||
paddle.shape(x33_1)[3]-paddle.shape(x3d)[3], |
||||
0, |
||||
paddle.shape(x33_1)[2]-paddle.shape(x3d)[2] |
||||
) |
||||
x3d = paddle.concat([F.pad(x3d, pad=pad3, mode='replicate'), x33_1, x33_2], 1) |
||||
x33d = self.do33d(self.conv33d(x3d)) |
||||
x32d = self.do32d(self.conv32d(x33d)) |
||||
x31d = self.do31d(self.conv31d(x32d)) |
||||
|
||||
# Stage 2d |
||||
x2d = self.upconv2(x31d) |
||||
pad2 = ( |
||||
0, |
||||
paddle.shape(x22_1)[3]-paddle.shape(x2d)[3], |
||||
0, |
||||
paddle.shape(x22_1)[2]-paddle.shape(x2d)[2] |
||||
) |
||||
x2d = paddle.concat([F.pad(x2d, pad=pad2, mode='replicate'), x22_1, x22_2], 1) |
||||
x22d = self.do22d(self.conv22d(x2d)) |
||||
x21d = self.do21d(self.conv21d(x22d)) |
||||
|
||||
# Stage 1d |
||||
x1d = self.upconv1(x21d) |
||||
pad1 = ( |
||||
0, |
||||
paddle.shape(x12_1)[3]-paddle.shape(x1d)[3], |
||||
0, |
||||
paddle.shape(x12_1)[2]-paddle.shape(x1d)[2] |
||||
) |
||||
x1d = paddle.concat([F.pad(x1d, pad=pad1, mode='replicate'), x12_1, x12_2], 1) |
||||
x12d = self.do12d(self.conv12d(x1d)) |
||||
x11d = self.conv11d(x12d) |
||||
|
||||
return x11d, |
||||
|
||||
def init_weight(self): |
||||
for sublayer in self.sublayers(): |
||||
if isinstance(sublayer, nn.Conv2D): |
||||
normal_init(sublayer.weight, std=0.001) |
||||
elif isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)): |
||||
constant_init(sublayer.weight, value=1.0) |
||||
constant_init(sublayer.bias, value=0.0) |
||||
|
||||
def _make_dropout(self): |
||||
if self.use_dropout: |
||||
return nn.Dropout2D(p=0.2) |
||||
else: |
||||
return Identity() |
@ -0,0 +1,224 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
import paddle.nn.functional as F |
||||
|
||||
|
||||
from .layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity |
||||
from .param_init import normal_init, constant_init |
||||
|
||||
|
||||
class UNetSiamDiff(nn.Layer): |
||||
""" |
||||
The FC-Siam-diff implementation based on PaddlePaddle. |
||||
|
||||
The original article refers to |
||||
Caye Daudt, R., et al. "Fully convolutional siamese networks for change detection" |
||||
(https://arxiv.org/abs/1810.08462). |
||||
|
||||
Args: |
||||
in_channels (int): The number of bands of the input images. |
||||
num_classes (int): The number of target classes. |
||||
use_dropout (bool, optional): A bool value that indicates whether to use dropout layers. When the model is trained |
||||
on a relatively small dataset, the dropout layers help prevent overfitting. Default: False. |
||||
""" |
||||
|
||||
def __init__( |
||||
self, |
||||
in_channels, |
||||
num_classes, |
||||
use_dropout=False |
||||
): |
||||
super().__init__() |
||||
|
||||
C1, C2, C3, C4, C5 = 16, 32, 64, 128, 256 |
||||
|
||||
self.use_dropout = use_dropout |
||||
|
||||
self.conv11 = Conv3x3(in_channels, C1, norm=True, act=True) |
||||
self.do11 = self._make_dropout() |
||||
self.conv12 = Conv3x3(C1, C1, norm=True, act=True) |
||||
self.do12 = self._make_dropout() |
||||
self.pool1 = MaxPool2x2() |
||||
|
||||
self.conv21 = Conv3x3(C1, C2, norm=True, act=True) |
||||
self.do21 = self._make_dropout() |
||||
self.conv22 = Conv3x3(C2, C2, norm=True, act=True) |
||||
self.do22 = self._make_dropout() |
||||
self.pool2 = MaxPool2x2() |
||||
|
||||
self.conv31 = Conv3x3(C2, C3, norm=True, act=True) |
||||
self.do31 = self._make_dropout() |
||||
self.conv32 = Conv3x3(C3, C3, norm=True, act=True) |
||||
self.do32 = self._make_dropout() |
||||
self.conv33 = Conv3x3(C3, C3, norm=True, act=True) |
||||
self.do33 = self._make_dropout() |
||||
self.pool3 = MaxPool2x2() |
||||
|
||||
self.conv41 = Conv3x3(C3, C4, norm=True, act=True) |
||||
self.do41 = self._make_dropout() |
||||
self.conv42 = Conv3x3(C4, C4, norm=True, act=True) |
||||
self.do42 = self._make_dropout() |
||||
self.conv43 = Conv3x3(C4, C4, norm=True, act=True) |
||||
self.do43 = self._make_dropout() |
||||
self.pool4 = MaxPool2x2() |
||||
|
||||
self.upconv4 = ConvTransposed3x3(C4, C4, output_padding=1) |
||||
|
||||
self.conv43d = Conv3x3(C5, C4, norm=True, act=True) |
||||
self.do43d = self._make_dropout() |
||||
self.conv42d = Conv3x3(C4, C4, norm=True, act=True) |
||||
self.do42d = self._make_dropout() |
||||
self.conv41d = Conv3x3(C4, C3, norm=True, act=True) |
||||
self.do41d = self._make_dropout() |
||||
|
||||
self.upconv3 = ConvTransposed3x3(C3, C3, output_padding=1) |
||||
|
||||
self.conv33d = Conv3x3(C4, C3, norm=True, act=True) |
||||
self.do33d = self._make_dropout() |
||||
self.conv32d = Conv3x3(C3, C3, norm=True, act=True) |
||||
self.do32d = self._make_dropout() |
||||
self.conv31d = Conv3x3(C3, C2, norm=True, act=True) |
||||
self.do31d = self._make_dropout() |
||||
|
||||
self.upconv2 = ConvTransposed3x3(C2, C2, output_padding=1) |
||||
|
||||
self.conv22d = Conv3x3(C3, C2, norm=True, act=True) |
||||
self.do22d = self._make_dropout() |
||||
self.conv21d = Conv3x3(C2, C1, norm=True, act=True) |
||||
self.do21d = self._make_dropout() |
||||
|
||||
self.upconv1 = ConvTransposed3x3(C1, C1, output_padding=1) |
||||
|
||||
self.conv12d = Conv3x3(C2, C1, norm=True, act=True) |
||||
self.do12d = self._make_dropout() |
||||
self.conv11d = Conv3x3(C1, num_classes) |
||||
|
||||
self.init_weight() |
||||
|
||||
def forward(self, t1, t2): |
||||
# Encode t1 |
||||
# Stage 1 |
||||
x11 = self.do11(self.conv11(t1)) |
||||
x12_1 = self.do12(self.conv12(x11)) |
||||
x1p = self.pool1(x12_1) |
||||
|
||||
# Stage 2 |
||||
x21 = self.do21(self.conv21(x1p)) |
||||
x22_1 = self.do22(self.conv22(x21)) |
||||
x2p = self.pool2(x22_1) |
||||
|
||||
# Stage 3 |
||||
x31 = self.do31(self.conv31(x2p)) |
||||
x32 = self.do32(self.conv32(x31)) |
||||
x33_1 = self.do33(self.conv33(x32)) |
||||
x3p = self.pool3(x33_1) |
||||
|
||||
# Stage 4 |
||||
x41 = self.do41(self.conv41(x3p)) |
||||
x42 = self.do42(self.conv42(x41)) |
||||
x43_1 = self.do43(self.conv43(x42)) |
||||
x4p = self.pool4(x43_1) |
||||
|
||||
# Encode t2 |
||||
# Stage 1 |
||||
x11 = self.do11(self.conv11(t2)) |
||||
x12_2 = self.do12(self.conv12(x11)) |
||||
x1p = self.pool1(x12_2) |
||||
|
||||
# Stage 2 |
||||
x21 = self.do21(self.conv21(x1p)) |
||||
x22_2 = self.do22(self.conv22(x21)) |
||||
x2p = self.pool2(x22_2) |
||||
|
||||
# Stage 3 |
||||
x31 = self.do31(self.conv31(x2p)) |
||||
x32 = self.do32(self.conv32(x31)) |
||||
x33_2 = self.do33(self.conv33(x32)) |
||||
x3p = self.pool3(x33_2) |
||||
|
||||
# Stage 4 |
||||
x41 = self.do41(self.conv41(x3p)) |
||||
x42 = self.do42(self.conv42(x41)) |
||||
x43_2 = self.do43(self.conv43(x42)) |
||||
x4p = self.pool4(x43_2) |
||||
|
||||
# Decode |
||||
# Stage 4d |
||||
x4d = self.upconv4(x4p) |
||||
pad4 = ( |
||||
0, |
||||
paddle.shape(x43_1)[3]-paddle.shape(x4d)[3], |
||||
0, |
||||
paddle.shape(x43_1)[2]-paddle.shape(x4d)[2] |
||||
) |
||||
x4d = paddle.concat([F.pad(x4d, pad=pad4, mode='replicate'), paddle.abs(x43_1-x43_2)], 1) |
||||
x43d = self.do43d(self.conv43d(x4d)) |
||||
x42d = self.do42d(self.conv42d(x43d)) |
||||
x41d = self.do41d(self.conv41d(x42d)) |
||||
|
||||
# Stage 3d |
||||
x3d = self.upconv3(x41d) |
||||
pad3 = ( |
||||
0, |
||||
paddle.shape(x33_1)[3]-paddle.shape(x3d)[3], |
||||
0, |
||||
paddle.shape(x33_1)[2]-paddle.shape(x3d)[2] |
||||
) |
||||
x3d = paddle.concat([F.pad(x3d, pad=pad3, mode='replicate'), paddle.abs(x33_1-x33_2)], 1) |
||||
x33d = self.do33d(self.conv33d(x3d)) |
||||
x32d = self.do32d(self.conv32d(x33d)) |
||||
x31d = self.do31d(self.conv31d(x32d)) |
||||
|
||||
# Stage 2d |
||||
x2d = self.upconv2(x31d) |
||||
pad2 = ( |
||||
0, |
||||
paddle.shape(x22_1)[3]-paddle.shape(x2d)[3], |
||||
0, |
||||
paddle.shape(x22_1)[2]-paddle.shape(x2d)[2] |
||||
) |
||||
x2d = paddle.concat([F.pad(x2d, pad=pad2, mode='replicate'), paddle.abs(x22_1-x22_2)], 1) |
||||
x22d = self.do22d(self.conv22d(x2d)) |
||||
x21d = self.do21d(self.conv21d(x22d)) |
||||
|
||||
# Stage 1d |
||||
x1d = self.upconv1(x21d) |
||||
pad1 = ( |
||||
0, |
||||
paddle.shape(x12_1)[3]-paddle.shape(x1d)[3], |
||||
0, |
||||
paddle.shape(x12_1)[2]-paddle.shape(x1d)[2] |
||||
) |
||||
x1d = paddle.concat([F.pad(x1d, pad=pad1, mode='replicate'), paddle.abs(x12_1-x12_2)], 1) |
||||
x12d = self.do12d(self.conv12d(x1d)) |
||||
x11d = self.conv11d(x12d) |
||||
|
||||
return x11d, |
||||
|
||||
def init_weight(self): |
||||
for sublayer in self.sublayers(): |
||||
if isinstance(sublayer, nn.Conv2D): |
||||
normal_init(sublayer.weight, std=0.001) |
||||
elif isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)): |
||||
constant_init(sublayer.weight, value=1.0) |
||||
constant_init(sublayer.bias, value=0.0) |
||||
|
||||
def _make_dropout(self): |
||||
if self.use_dropout: |
||||
return nn.Dropout2D(p=0.2) |
||||
else: |
||||
return Identity() |
Loading…
Reference in new issue