You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
310 lines
9.9 KiB
310 lines
9.9 KiB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
import paddle |
|
import paddle.nn as nn |
|
import paddle.nn.functional as F |
|
|
|
from .backbones import resnet |
|
from .layers import Conv1x1, Conv3x3, get_bn_layer, Identity |
|
from .param_init import KaimingInitMixin |
|
|
|
|
|
class STANet(nn.Layer):
    """
    STANet for change detection, implemented with PaddlePaddle.

    Reference:
        H. Chen and Z. Shi, "A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing
        Image Change Detection" (https://www.mdpi.com/2072-4292/12/10/1662).

    This implementation departs from the original work in two ways:
        1. Layer 4 of the ResNet backbone does not use multiple dilation rates.
        2. The original metric learning-based head is replaced by a classification head to stabilize the
           training process.

    Args:
        in_channels (int): Number of bands of each input image.
        num_classes (int): Number of target classes.
        att_type (str, optional): Attention module to use, either 'PAM' or 'BAM'. Default: 'BAM'.
        ds_factor (int, optional): Downsampling factor of the attention modules. When greater than 1, the
            features are first processed by an average pooling layer with a kernel size of `ds_factor`
            before the attention scores are computed. Default: 1.

    Raises:
        ValueError: If `att_type` is not a supported attention type.
    """

    def __init__(self, in_channels, num_classes, att_type='BAM', ds_factor=1):
        super(STANet, self).__init__()

        # Channel width shared by the extractor output, the attention module and the head input.
        width = 64

        self.extract = build_feat_extractor(in_ch=in_channels, width=width)
        self.attend = build_sta_module(in_ch=width, att_type=att_type, ds=ds_factor)
        head_layers = [
            Conv3x3(width, width, norm=True, act=True),
            Conv3x3(width, num_classes),
        ]
        self.conv_out = nn.Sequential(*head_layers)

        self.init_weight()

    def forward(self, t1, t2):
        """
        Predict change logits for an image pair.

        Args:
            t1: Image of the first time step (NCHW, presumably).
            t2: Image of the second time step, same layout as `t1`.

        Returns:
            list: A single-element list holding the output of the classification head, resized to the
                spatial size of `t1`.
        """
        # The same extractor (shared weights) encodes both images.
        feat1, feat2 = self.extract(t1), self.extract(t2)
        feat1, feat2 = self.attend(feat1, feat2)

        # Absolute feature difference, resized to the input resolution before the head.
        diff = paddle.abs(feat1 - feat2)
        out_size = paddle.shape(t1)[2:]
        diff = F.interpolate(diff, size=out_size, mode='bilinear', align_corners=True)

        return [self.conv_out(diff)]

    def init_weight(self):
        # Intentionally a no-op: the feature extractor is initialized by its own constructors.
        # Note that self.attend and self.conv_out currently keep the default initialization.
        pass
|
|
|
|
|
def build_feat_extractor(in_ch, width):
    """
    Build the feature extractor: a ResNet-18 backbone followed by a decoder.

    Args:
        in_ch (int): Number of input image channels.
        width (int): Number of channels of the decoder output.

    Returns:
        nn.Sequential: The backbone and decoder chained together.
    """
    backbone = Backbone(in_ch, 'resnet18')
    decoder = Decoder(width)
    return nn.Sequential(backbone, decoder)
|
|
|
|
|
def build_sta_module(in_ch, att_type, ds):
    """
    Build the spatial-temporal attention module wrapped in `Attention`.

    Args:
        in_ch (int): Number of channels of the features to attend over.
        att_type (str): 'BAM' for the basic attention module or 'PAM' for the pyramid attention module.
        ds (int): Downsampling factor forwarded to the attention module.

    Returns:
        Attention: A layer that takes two feature maps and returns the two attended feature maps.

    Raises:
        ValueError: If `att_type` is neither 'BAM' nor 'PAM'.
    """
    if att_type == 'BAM':
        att = BAM(in_ch, ds)
    elif att_type == 'PAM':
        att = PAM(in_ch, ds)
    else:
        # Report the rejected value so misconfigurations are easy to diagnose.
        raise ValueError(
            "Unsupported attention type: {!r}. Expected 'BAM' or 'PAM'.".format(att_type))
    return Attention(att)
|
|
|
|
|
class Backbone(nn.Layer, KaimingInitMixin):
    """
    ResNet backbone that returns the outputs of its four residual stages.

    Args:
        in_ch (int): Number of input channels. When not 3, the first convolution is replaced by a newly
            created one that accepts `in_ch` channels.
        arch (str): Backbone architecture, one of 'resnet18', 'resnet34' and 'resnet50'.
        pretrained (bool, optional): Forwarded to the ResNet constructor. When False, weights are
            initialized with `self.init_weight()` (presumably provided by KaimingInitMixin). Default: True.
        strides (tuple[int], optional): Strides forwarded to the ResNet constructor; `strides[0]` is also
            used by the replacement first convolution. Default: (2, 1, 2, 2, 2).

    Raises:
        ValueError: If `arch` is not a supported architecture.
    """

    def __init__(self, in_ch, arch, pretrained=True, strides=(2, 1, 2, 2, 2)):
        super(Backbone, self).__init__()

        builders = {
            'resnet18': resnet.resnet18,
            'resnet34': resnet.resnet34,
            'resnet50': resnet.resnet50,
        }
        if arch not in builders:
            raise ValueError(
                "Unsupported backbone architecture: {!r}. Expected one of {}.".format(
                    arch, sorted(builders)))
        self.resnet = builders[arch](
            pretrained=pretrained, strides=strides, norm_layer=get_bn_layer())

        self._trim_resnet()

        if in_ch != 3:
            # The replacement layer does not receive pretrained weights, even when pretrained=True.
            self.resnet.conv1 = nn.Conv2D(
                in_ch,
                64,
                kernel_size=7,
                stride=strides[0],
                padding=3,
                bias_attr=False)

        if not pretrained:
            self.init_weight()

    def forward(self, x):
        """Return the outputs of ResNet layer1 to layer4, in that order."""
        # Stem: conv -> bn -> relu -> maxpool.
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x1 = self.resnet.layer1(x)
        x2 = self.resnet.layer2(x1)
        x3 = self.resnet.layer3(x2)
        x4 = self.resnet.layer4(x3)

        return x1, x2, x3, x4

    def _trim_resnet(self):
        # The global pooling and fc head are never used by forward(); replace them with identities.
        self.resnet.avgpool = Identity()
        self.resnet.fc = Identity()
|
|
|
|
|
class Decoder(nn.Layer, KaimingInitMixin):
    """
    Fuse four multi-level backbone features into a single feature map.

    Each input feature is reduced to 96 channels by a 1x1 convolution. Levels 2-4 are resized to the
    spatial size of level 1. The four maps are concatenated (4 x 96 = 384 channels) and projected to
    `f_ch` channels.

    Args:
        f_ch (int): Number of output channels.
        in_chs (tuple[int], optional): Channel counts of the four input features, in order. These must
            match the backbone's stage outputs. Default: (64, 128, 256, 512), the previously hard-coded
            values.

    Raises:
        ValueError: If `in_chs` does not contain exactly four entries.
    """

    def __init__(self, f_ch, in_chs=(64, 128, 256, 512)):
        super(Decoder, self).__init__()
        if len(in_chs) != 4:
            raise ValueError(
                "Decoder expects exactly 4 input channel counts, got {}.".format(len(in_chs)))
        # Attribute names dr1..dr4 are kept so existing checkpoints still load.
        self.dr1 = Conv1x1(in_chs[0], 96, norm=True, act=True)
        self.dr2 = Conv1x1(in_chs[1], 96, norm=True, act=True)
        self.dr3 = Conv1x1(in_chs[2], 96, norm=True, act=True)
        self.dr4 = Conv1x1(in_chs[3], 96, norm=True, act=True)
        self.conv_out = nn.Sequential(
            Conv3x3(
                384, 256, norm=True, act=True),
            nn.Dropout(0.5),
            Conv1x1(
                256, f_ch, norm=True, act=True))

        self.init_weight()

    def forward(self, feats):
        """Fuse `feats` (a sequence of four feature maps) into one map with `f_ch` channels."""
        f1 = self.dr1(feats[0])
        f2 = self.dr2(feats[1])
        f3 = self.dr3(feats[2])
        f4 = self.dr4(feats[3])

        # Resize levels 2-4 to the spatial size of level 1 so they can be concatenated.
        size = paddle.shape(f1)[2:]
        f2, f3, f4 = [
            F.interpolate(f, size=size, mode='bilinear', align_corners=True)
            for f in (f2, f3, f4)
        ]

        x = paddle.concat([f1, f2, f3, f4], axis=1)
        return self.conv_out(x)
|
|
|
|
|
class BAM(nn.Layer):
    """
    Basic spatial-temporal attention module (BAM).

    The two temporal feature maps are merged into one map. Scaled dot-product self-attention is then
    computed over all positions of the merged (and pooled) map, followed by a residual connection.

    Args:
        in_ch (int): Number of input channels.
        ds (int): Downsampling factor. The merged input is average-pooled with kernel size `ds` before
            attention, and the attention output is upsampled by `ds` afterwards.
    """

    def __init__(self, in_ch, ds):
        super(BAM, self).__init__()

        self.ds = ds
        self.pool = nn.AvgPool2D(self.ds)

        # Queries and keys are projected to in_ch // 8 channels; values keep in_ch channels.
        self.val_ch = in_ch
        self.key_ch = in_ch // 8
        self.conv_q = Conv1x1(in_ch, self.key_ch)
        self.conv_k = Conv1x1(in_ch, self.key_ch)
        self.conv_v = Conv1x1(in_ch, self.val_ch)

        self.softmax = nn.Softmax(axis=-1)

    def forward(self, x):
        # x is expected to be rank 5 with a trailing axis of size 2 (the two time steps). Merging the last
        # two axes yields a single 4-D map that the attention below treats as one image.
        merged = x.flatten(-2)
        pooled = self.pool(merged)

        b, c, h, w = paddle.shape(pooled)
        n = h * w

        q = self.conv_q(pooled).reshape((b, -1, n))
        q = q.transpose((0, 2, 1))  # (b, n, key_ch)
        k = self.conv_k(pooled).reshape((b, -1, n))  # (b, key_ch, n)
        scores = (self.key_ch**(-0.5)) * paddle.bmm(q, k)

        attn = self.softmax(scores)  # (b, n, n), normalized over key positions

        v = self.conv_v(pooled).reshape((b, -1, n))  # (b, val_ch, n)
        ctx = paddle.bmm(v, attn.transpose((0, 2, 1)))
        ctx = ctx.reshape((b, c, h, w))

        # NOTE(review): if the merged height/width is not divisible by ds, the upsampled map may not match
        # `merged` in shape and the residual addition below would fail.
        ctx = F.interpolate(ctx, scale_factor=self.ds)
        ctx = ctx + merged

        # Split the merged last axis back into (w, 2) to recover the two time steps.
        restored_shape = tuple(ctx.shape[:-1]) + (ctx.shape[-1] // 2, 2)
        return ctx.reshape(restored_shape)
|
|
|
|
|
class PAMBlock(nn.Layer):
    """
    One scale of the pyramid attention module.

    The (optionally downsampled) input is divided into `scale` x `scale` subregions. Scaled dot-product
    self-attention is computed independently inside each subregion, and the subregions are then
    reassembled into the full map.

    Args:
        in_ch (int): Number of input channels.
        scale (int, optional): Number of subregions along each spatial axis. Default: 1.
        ds (int, optional): Downsampling factor. The input is average-pooled with kernel size `ds` before
            attention, and the output is upsampled by `ds`. Default: 1.

    Raises:
        ValueError: (from forward) if the pooled height or width is not divisible by `scale`.
    """

    def __init__(self, in_ch, scale=1, ds=1):
        super(PAMBlock, self).__init__()

        self.scale = scale
        self.ds = ds
        self.pool = nn.AvgPool2D(self.ds)

        # Queries and keys are projected to in_ch // 8 channels; values keep in_ch channels.
        self.val_ch = in_ch
        self.key_ch = in_ch // 8
        self.conv_q = Conv1x1(in_ch, self.key_ch, norm=True)
        self.conv_k = Conv1x1(in_ch, self.key_ch, norm=True)
        self.conv_v = Conv1x1(in_ch, self.val_ch)

    def forward(self, x):
        x_rs = self.pool(x)

        # Get query, key, and value.
        query = self.conv_q(x_rs)
        key = self.conv_k(x_rs)
        value = self.conv_v(x_rs)

        # Split the whole image into subregions.
        b, c, h, w = x_rs.shape

        query = self._split_subregions(query)
        key = self._split_subregions(key)
        value = self._split_subregions(value)

        # Perform subregion-wise attention.
        out = self._attend(query, key, value)

        # Stack subregions to reconstruct the whole image, then undo the downsampling.
        out = self._recons_whole(out, b, c, h, w)
        out = F.interpolate(out, scale_factor=self.ds)
        return out

    def _attend(self, query, key, value):
        """Scaled dot-product attention over (batch, channels, positions) tensors."""
        # (B, N, key_ch) x (B, key_ch, N) -> (B, N, N) batch matrix multiplication.
        energy = paddle.bmm(query.transpose((0, 2, 1)), key)
        energy = (self.key_ch**(-0.5)) * energy
        attention = F.softmax(energy, axis=-1)
        # (B, C, N) x (B, N, N) -> (B, C, N)
        out = paddle.bmm(value, attention.transpose((0, 2, 1)))
        return out

    def _split_subregions(self, x):
        """Reshape (b, c, h, w) into (b * scale**2, c, (h // scale) * (w // scale))."""
        b, c, h, w = x.shape
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if h % self.scale != 0 or w % self.scale != 0:
            raise ValueError(
                "Feature map size ({}, {}) must be divisible by the PAM scale {}.".format(
                    h, w, self.scale))
        x = x.reshape(
            (b, c, self.scale, h // self.scale, self.scale, w // self.scale))

        # Move the two subregion-index axes next to the batch axis.
        x = x.transpose((0, 2, 4, 1, 3, 5))

        x = x.reshape((b * self.scale * self.scale, c, -1))
        return x

    def _recons_whole(self, x, b, c, h, w):
        """Inverse of `_split_subregions`: reassemble subregions into a (b, c, h, w) map."""
        x = x.reshape(
            (b, self.scale, self.scale, c, h // self.scale, w // self.scale))
        x = x.transpose((0, 3, 1, 4, 2, 5)).reshape((b, c, h, w))
        return x
|
|
|
|
|
class PAM(nn.Layer):
    """
    Pyramid spatial-temporal attention module (PAM).

    One PAMBlock is run per scale on the merged bi-temporal features. Their outputs are concatenated
    along the channel axis and fused by a bias-free 1x1 convolution.

    Args:
        in_ch (int): Number of input channels, which is also the number of output channels.
        ds (int): Downsampling factor passed to every PAMBlock.
        scales (tuple[int], optional): Subregion counts per spatial axis, one PAMBlock each.
            Default: (1, 2, 4, 8).
    """

    def __init__(self, in_ch, ds, scales=(1, 2, 4, 8)):
        super(PAM, self).__init__()

        blocks = []
        for s in scales:
            blocks.append(PAMBlock(in_ch, scale=s, ds=ds))
        self.stages = nn.LayerList(blocks)
        self.conv_out = Conv1x1(in_ch * len(scales), in_ch, bias=False)

    def forward(self, x):
        # Merge the trailing time axis into the last spatial axis so every block sees both time steps.
        merged = x.flatten(-2)
        fused = self.conv_out(
            paddle.concat([stage(merged) for stage in self.stages], axis=1))

        # Split the merged last axis back into (w, 2) to recover the two time steps.
        restored_shape = tuple(fused.shape[:-1]) + (fused.shape[-1] // 2, 2)
        return fused.reshape(restored_shape)
|
|
|
|
|
class Attention(nn.Layer):
    """
    Adapt a spatial-temporal attention module to take and return a pair of feature maps.

    The two inputs are stacked along a new trailing axis, passed through `att`, and split again along
    that axis.

    Args:
        att (nn.Layer): Attention module (e.g. BAM or PAM) whose output keeps the trailing time axis.
    """

    def __init__(self, att):
        super(Attention, self).__init__()
        self.att = att

    def forward(self, x1, x2):
        pair = paddle.stack([x1, x2], axis=-1)
        attended = self.att(pair)
        first = attended[..., 0]
        second = attended[..., 1]
        return first, second
|
|
|