[Feature] Add STANet

own
Bobholamovic 3 years ago
parent 6932bf6e16
commit 91c64d923a
  1. 3
      paddlers/models/cd/models/__init__.py
  2. 13
      paddlers/models/cd/models/backbones/__init__.py
  3. 358
      paddlers/models/cd/models/backbones/resnet.py
  4. 3
      paddlers/models/cd/models/layers/blocks.py
  5. 297
      paddlers/models/cd/models/stanet.py
  6. 22
      paddlers/tasks/changedetector.py

@ -15,4 +15,5 @@
from .cdnet import CDNet
from .unet_ef import UNetEarlyFusion
from .unet_siamconc import UNetSiamConc
from .unet_siamdiff import UNetSiamDiff
from .unet_siamdiff import UNetSiamDiff
from .stanet import STANet

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,358 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/PaddlePaddle/Paddle/blob/release/2.2/python/paddle/vision/models/resnet.py
## Original head information
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
from paddle.utils.download import get_weights_path_from_url
__all__ = []
model_urls = {
'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams',
'cf548f46534aa3560945be4b95cd11c4'),
'resnet34': ('https://paddle-hapi.bj.bcebos.com/models/resnet34.pdparams',
'8d2275cf8706028345f78ac0e1d31969'),
'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams',
'ca6f485ee1ab0492d38f323885b0ad80'),
'resnet101': ('https://paddle-hapi.bj.bcebos.com/models/resnet101.pdparams',
'02f35f034ca3858e1e54d4036443c92d'),
'resnet152': ('https://paddle-hapi.bj.bcebos.com/models/resnet152.pdparams',
'7ad16a2f1e7333859ff986138630fd7a'),
}
class BasicBlock(nn.Layer):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
groups=1,
base_width=64,
dilation=1,
norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2D
if dilation > 1:
raise NotImplementedError(
"Dilation > 1 not supported in BasicBlock")
self.conv1 = nn.Conv2D(
inplanes, planes, 3, padding=1, stride=stride, bias_attr=False)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class BottleneckBlock(nn.Layer):
expansion = 4
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
groups=1,
base_width=64,
dilation=1,
norm_layer=None):
super(BottleneckBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2D
width = int(planes * (base_width / 64.)) * groups
self.conv1 = nn.Conv2D(inplanes, width, 1, bias_attr=False)
self.bn1 = norm_layer(width)
self.conv2 = nn.Conv2D(
width,
width,
3,
padding=dilation,
stride=stride,
groups=groups,
dilation=dilation,
bias_attr=False)
self.bn2 = norm_layer(width)
self.conv3 = nn.Conv2D(
width, planes * self.expansion, 1, bias_attr=False)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU()
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Layer):
"""ResNet model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
Block (BasicBlock|BottleneckBlock): block module of model.
depth (int): layers of resnet, default: 50.
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
with_pool (bool): use pool before the last fc layer or not. Default: True.
Examples:
.. code-block:: python
from paddle.vision.models import ResNet
from paddle.vision.models.resnet import BottleneckBlock, BasicBlock
resnet50 = ResNet(BottleneckBlock, 50)
resnet18 = ResNet(BasicBlock, 18)
"""
def __init__(self, block, depth, num_classes=1000, with_pool=True, strides=(1,1,2,2,2), norm_layer=None):
super(ResNet, self).__init__()
layer_cfg = {
18: [2, 2, 2, 2],
34: [3, 4, 6, 3],
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
152: [3, 8, 36, 3]
}
layers = layer_cfg[depth]
self.num_classes = num_classes
self.with_pool = with_pool
self._norm_layer = nn.BatchNorm2D if norm_layer is None else norm_layer
self.inplanes = 64
self.dilation = 1
self.conv1 = nn.Conv2D(
3,
self.inplanes,
kernel_size=7,
stride=strides[0],
padding=3,
bias_attr=False)
self.bn1 = self._norm_layer(self.inplanes)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[1])
self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[2])
self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[3])
self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[4])
if with_pool:
self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
if num_classes > 0:
self.fc = nn.Linear(512 * block.expansion, num_classes)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2D(
self.inplanes,
planes * block.expansion,
1,
stride=stride,
bias_attr=False),
norm_layer(planes * block.expansion), )
layers = []
layers.append(
block(self.inplanes, planes, stride, downsample, 1, 64,
previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
if self.with_pool:
x = self.avgpool(x)
if self.num_classes > 0:
x = paddle.flatten(x, 1)
x = self.fc(x)
return x
def _resnet(arch, Block, depth, pretrained, **kwargs):
model = ResNet(Block, depth, **kwargs)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])
param = paddle.load(weight_path)
model.set_dict(param)
return model
def resnet18(pretrained=False, **kwargs):
"""ResNet 18-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet18
# build model
model = resnet18()
# build model and load imagenet pretrained weight
# model = resnet18(pretrained=True)
"""
return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs)
def resnet34(pretrained=False, **kwargs):
"""ResNet 34-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet34
# build model
model = resnet34()
# build model and load imagenet pretrained weight
# model = resnet34(pretrained=True)
"""
return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs)
def resnet50(pretrained=False, **kwargs):
"""ResNet 50-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet50
# build model
model = resnet50()
# build model and load imagenet pretrained weight
# model = resnet50(pretrained=True)
"""
return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs)
def resnet101(pretrained=False, **kwargs):
"""ResNet 101-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet101
# build model
model = resnet101()
# build model and load imagenet pretrained weight
# model = resnet101(pretrained=True)
"""
return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs)
def resnet152(pretrained=False, **kwargs):
"""ResNet 152-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
Examples:
.. code-block:: python
from paddle.vision.models import resnet152
# build model
model = resnet152()
# build model and load imagenet pretrained weight
# model = resnet152(pretrained=True)
"""
return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs)

@ -19,7 +19,8 @@ __all__ = [
'BasicConv', 'Conv1x1', 'Conv3x3', 'Conv7x7',
'MaxPool2x2', 'MaxUnPool2x2',
'ConvTransposed3x3',
'Identity'
'Identity',
'get_norm_layer', 'get_act_layer'
]

@ -0,0 +1,297 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .backbones import resnet
from .layers import Conv1x1, Conv3x3, get_norm_layer, Identity
from .param_init import KaimingInitMixin
class STANet(nn.Layer):
"""
The STANet implementation based on PaddlePaddle.
The original article refers to
H. Chen and Z. Shi, "A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection"
(https://www.mdpi.com/2072-4292/12/10/1662)
Note that this implementation differs from the original work in two aspects:
1. We do not use multiple dilation rates in layer 4 of the ResNet backbone.
2. A classification head is used in place of the original metric learning-based head to stablize the training process.
Args:
in_channels (int): The number of bands of the input images.
num_classes (int): The number of target classes.
att_type (str, optional): The attention module used in the model. Options are 'PAM' and 'BAM'. Default: 'BAM'.
ds_factor (int, optional): The downsampling factor of the attention modules. When `ds_factor` is set to values
greater than 1, the input features will first be processed by an average pooling layer with the kernel size of
`ds_factor`, before being used to calculate the attention scores. Default: 1.
Raises:
ValueError: When `att_type` has an illeagal value (unsupported attention type).
"""
def __init__(
self,
in_channels,
num_classes,
att_type='BAM',
ds_factor=1
):
super().__init__()
WIDTH = 64
self.extract = build_feat_extractor(in_ch=in_channels, width=WIDTH)
self.attend = build_sta_module(in_ch=WIDTH, att_type=att_type, ds=ds_factor)
self.conv_out = nn.Sequential(
Conv3x3(WIDTH, WIDTH, norm=True, act=True),
Conv3x3(WIDTH, num_classes)
)
self.init_weight()
def forward(self, t1, t2):
f1 = self.extract(t1)
f2 = self.extract(t2)
f1, f2 = self.attend(f1, f2)
y = paddle.abs(f1- f2)
y = F.interpolate(y, size=t1.shape[2:], mode='bilinear', align_corners=True)
pred = self.conv_out(y)
return pred,
def init_weight(self):
# Do nothing here as the encoder and decoder weights have already been initialized.
# Note however that currently self.attend and self.conv_out use the default initilization method.
pass
def build_feat_extractor(in_ch, width):
return nn.Sequential(
Backbone(in_ch, 'resnet18'),
Decoder(width)
)
def build_sta_module(in_ch, att_type, ds):
if att_type == 'BAM':
return Attention(BAM(in_ch, ds))
elif att_type == 'PAM':
return Attention(PAM(in_ch, ds))
else:
raise ValueError
class Backbone(nn.Layer, KaimingInitMixin):
def __init__(self, in_ch, arch, pretrained=True, strides=(2,1,2,2,2)):
super().__init__()
if arch == 'resnet18':
self.resnet = resnet.resnet18(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer())
elif arch == 'resnet34':
self.resnet = resnet.resnet34(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer())
elif arch == 'resnet50':
self.resnet = resnet.resnet50(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer())
else:
raise ValueError
self._trim_resnet()
if in_ch != 3:
self.resnet.conv1 = nn.Conv2D(
in_ch,
64,
kernel_size=7,
stride=strides[0],
padding=3,
bias_attr=False
)
if not pretrained:
self.init_weight()
def forward(self, x):
x = self.resnet.conv1(x)
x = self.resnet.bn1(x)
x = self.resnet.relu(x)
x = self.resnet.maxpool(x)
x1 = self.resnet.layer1(x)
x2 = self.resnet.layer2(x1)
x3 = self.resnet.layer3(x2)
x4 = self.resnet.layer4(x3)
return x1, x2, x3, x4
def _trim_resnet(self):
self.resnet.avgpool = Identity()
self.resnet.fc = Identity()
class Decoder(nn.Layer, KaimingInitMixin):
def __init__(self, f_ch):
super().__init__()
self.dr1 = Conv1x1(64, 96, norm=True, act=True)
self.dr2 = Conv1x1(128, 96, norm=True, act=True)
self.dr3 = Conv1x1(256, 96, norm=True, act=True)
self.dr4 = Conv1x1(512, 96, norm=True, act=True)
self.conv_out = nn.Sequential(
Conv3x3(384, 256, norm=True, act=True),
nn.Dropout(0.5),
Conv1x1(256, f_ch, norm=True, act=True)
)
self.init_weight()
def forward(self, feats):
f1 = self.dr1(feats[0])
f2 = self.dr2(feats[1])
f3 = self.dr3(feats[2])
f4 = self.dr4(feats[3])
f2 = F.interpolate(f2, size=f1.shape[2:], mode='bilinear', align_corners=True)
f3 = F.interpolate(f3, size=f1.shape[2:], mode='bilinear', align_corners=True)
f4 = F.interpolate(f4, size=f1.shape[2:], mode='bilinear', align_corners=True)
x = paddle.concat([f1, f2, f3, f4], axis=1)
y = self.conv_out(x)
return y
class BAM(nn.Layer):
def __init__(self, in_ch, ds):
super().__init__()
self.ds = ds
self.pool = nn.AvgPool2D(self.ds)
self.val_ch = in_ch
self.key_ch = in_ch // 8
self.conv_q = Conv1x1(in_ch, self.key_ch)
self.conv_k = Conv1x1(in_ch, self.key_ch)
self.conv_v = Conv1x1(in_ch, self.val_ch)
self.softmax = nn.Softmax(axis=-1)
def forward(self, x):
x = x.flatten(-2)
x_rs = self.pool(x)
b, c, h, w = x_rs.shape
query = self.conv_q(x_rs).reshape((b,-1,h*w)).transpose((0,2,1))
key = self.conv_k(x_rs).reshape((b,-1,h*w))
energy = paddle.bmm(query, key)
energy = (self.key_ch**(-0.5)) * energy
attention = self.softmax(energy)
value = self.conv_v(x_rs).reshape((b,-1,w*h))
out = paddle.bmm(value, attention.transpose((0,2,1)))
out = out.reshape((b,c,h,w))
out = F.interpolate(out, scale_factor=self.ds)
out = out + x
return out.reshape(out.shape[:-1]+[out.shape[-1]//2, 2])
class PAMBlock(nn.Layer):
def __init__(self, in_ch, scale=1, ds=1):
super().__init__()
self.scale = scale
self.ds = ds
self.pool = nn.AvgPool2D(self.ds)
self.val_ch = in_ch
self.key_ch = in_ch // 8
self.conv_q = Conv1x1(in_ch, self.key_ch, norm=True)
self.conv_k = Conv1x1(in_ch, self.key_ch, norm=True)
self.conv_v = Conv1x1(in_ch, self.val_ch)
def forward(self, x):
x_rs = self.pool(x)
# Get query, key, and value.
query = self.conv_q(x_rs)
key = self.conv_k(x_rs)
value = self.conv_v(x_rs)
# Split the whole image into subregions.
b, c, h, w = x_rs.shape
query = self._split_subregions(query)
key = self._split_subregions(key)
value = self._split_subregions(value)
# Perform subregion-wise attention.
out = self._attend(query, key, value)
# Stack subregions to reconstruct the whole image.
out = self._recons_whole(out, b, c, h, w)
out = F.interpolate(out, scale_factor=self.ds)
return out
def _attend(self, query, key, value):
energy = paddle.bmm(query.transpose((0,2,1)), key) # batch matrix multiplication
energy = (self.key_ch**(-0.5)) * energy
attention = F.softmax(energy, axis=-1)
out = paddle.bmm(value, attention.transpose((0,2,1)))
return out
def _split_subregions(self, x):
b, c, h, w = x.shape
assert h % self.scale == 0 and w % self.scale == 0
x = x.reshape((b, c, self.scale, h//self.scale, self.scale, w//self.scale))
x = x.transpose((0,2,4,1,3,5)).reshape((b*self.scale*self.scale, c, -1))
return x
def _recons_whole(self, x, b, c, h, w):
x = x.reshape((b, self.scale, self.scale, c, h//self.scale, w//self.scale))
x = x.transpose((0,3,1,4,2,5)).reshape((b, c, h, w))
return x
class PAM(nn.Layer):
def __init__(self, in_ch, ds, scales=(1,2,4,8)):
super().__init__()
self.stages = nn.LayerList([
PAMBlock(in_ch, scale=s, ds=ds)
for s in scales
])
self.conv_out = Conv1x1(in_ch*len(scales), in_ch, bias=False)
def forward(self, x):
x = x.flatten(-2)
res = [stage(x) for stage in self.stages]
out = self.conv_out(paddle.concat(res, axis=1))
return out.reshape(out.shape[:-1]+[out.shape[-1]//2, 2])
class Attention(nn.Layer):
def __init__(self, att):
super().__init__()
self.att = att
def forward(self, x1, x2):
x = paddle.stack([x1, x2], axis=-1)
y = self.att(x)
return y[...,0], y[...,1]

@ -31,7 +31,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict
from paddlers.transforms import ImgDecoder, Resize
import paddlers.models.cd as cd
__all__ = ["CDNet", "UNetEarlyFusion", "UNetSiamConc", "UNetSiamDiff"]
__all__ = ["CDNet", "UNetEarlyFusion", "UNetSiamConc", "UNetSiamDiff", "STANet"]
class BaseChangeDetector(BaseModel):
@ -716,4 +716,24 @@ class UNetSiamDiff(BaseChangeDetector):
model_name='UNetSiamDiff',
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
**params)
class STANet(BaseChangeDetector):
def __init__(self,
num_classes=2,
use_mixed_loss=False,
in_channels=3,
att_type='BAM',
ds_factor=1,
**params):
params.update({
'in_channels': in_channels,
'att_type': att_type,
'ds_factor': ds_factor
})
super(STANet, self).__init__(
model_name='STANet',
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
**params)
Loading…
Cancel
Save