[Feature] Add DSIFN

own
Bobholamovic 3 years ago
parent 8fd01a1d23
commit d35cffa344
  1. 1
      paddlers/models/cd/models/__init__.py
  2. 202
      paddlers/models/cd/models/dsifn.py
  3. 27
      paddlers/tasks/changedetector.py

@ -19,3 +19,4 @@ from .unet_siamdiff import UNetSiamDiff
from .stanet import STANet from .stanet import STANet
from .bit import BIT from .bit import BIT
from .snunet import SNUNet from .snunet import SNUNet
from .dsifn import DSIFN

@ -0,0 +1,202 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.vision.models import vgg16
from .layers import Conv1x1, make_norm, ChannelAttention, SpatialAttention
class DSIFN(nn.Layer):
    """
    PaddlePaddle implementation of DSIFN.

    Reference:
        C. Zhang, et al., "A deeply supervised image fusion network for change detection in high resolution bi-temporal
        remote sensing images"
        (https://www.sciencedirect.com/science/article/pii/S0924271620301532)

    Unlike a fixed binary head, every prediction head here emits `num_classes` channels.

    Args:
        num_classes (int): The number of target classes.
        use_dropout (bool, optional): Whether the 3x3 conv blocks of the decoder end with a dropout layer. Dropout can
            help against overfitting when the training set is relatively small. Default: False.
    """

    def __init__(self, num_classes, use_dropout=False):
        super().__init__()

        # One feature extractor instance is registered under both names, so the two temporal branches share weights.
        shared_encoder = VGG16FeaturePicker()
        self.encoder1 = shared_encoder
        self.encoder2 = shared_encoder

        # Spatial attention modules, one per decoding stage.
        self.sa1 = SpatialAttention()
        self.sa2 = SpatialAttention()
        self.sa3 = SpatialAttention()
        self.sa4 = SpatialAttention()
        self.sa5 = SpatialAttention()

        # Stage 1. NOTE: `ca1` and `bn_ca1` are registered but never applied in forward().
        self.ca1 = ChannelAttention(in_ch=1024)
        self.bn_ca1 = make_norm(1024)
        self.o1_conv1 = conv2d_bn(1024, 512, use_dropout)
        self.o1_conv2 = conv2d_bn(512, 512, use_dropout)
        self.bn_sa1 = make_norm(512)
        self.o1_conv3 = Conv1x1(512, num_classes)
        self.trans_conv1 = nn.Conv2DTranspose(
            512, 512, kernel_size=2, stride=2)

        # Stage 2. NOTE: `bn_ca2` is registered but never applied in forward().
        self.ca2 = ChannelAttention(in_ch=1536)
        self.bn_ca2 = make_norm(1536)
        self.o2_conv1 = conv2d_bn(1536, 512, use_dropout)
        self.o2_conv2 = conv2d_bn(512, 256, use_dropout)
        self.o2_conv3 = conv2d_bn(256, 256, use_dropout)
        self.bn_sa2 = make_norm(256)
        self.o2_conv4 = Conv1x1(256, num_classes)
        self.trans_conv2 = nn.Conv2DTranspose(
            256, 256, kernel_size=2, stride=2)

        # Stage 3.
        self.ca3 = ChannelAttention(in_ch=768)
        self.o3_conv1 = conv2d_bn(768, 256, use_dropout)
        self.o3_conv2 = conv2d_bn(256, 128, use_dropout)
        self.o3_conv3 = conv2d_bn(128, 128, use_dropout)
        self.bn_sa3 = make_norm(128)
        self.o3_conv4 = Conv1x1(128, num_classes)
        self.trans_conv3 = nn.Conv2DTranspose(
            128, 128, kernel_size=2, stride=2)

        # Stage 4.
        self.ca4 = ChannelAttention(in_ch=384)
        self.o4_conv1 = conv2d_bn(384, 128, use_dropout)
        self.o4_conv2 = conv2d_bn(128, 64, use_dropout)
        self.o4_conv3 = conv2d_bn(64, 64, use_dropout)
        self.bn_sa4 = make_norm(64)
        self.o4_conv4 = Conv1x1(64, num_classes)
        self.trans_conv4 = nn.Conv2DTranspose(64, 64, kernel_size=2, stride=2)

        # Stage 5.
        self.ca5 = ChannelAttention(in_ch=192)
        self.o5_conv1 = conv2d_bn(192, 64, use_dropout)
        self.o5_conv2 = conv2d_bn(64, 32, use_dropout)
        self.o5_conv3 = conv2d_bn(32, 16, use_dropout)
        self.bn_sa5 = make_norm(16)
        self.o5_conv4 = Conv1x1(16, num_classes)

    def _upsample_like(self, logits, ref):
        """Bilinearly resize `logits` to the size given by dimensions 2 onward of `ref`."""
        return F.interpolate(
            logits,
            size=paddle.shape(ref)[2:],
            mode='bilinear',
            align_corners=True)

    def forward(self, t1, t2):
        """
        Run the network on a bi-temporal image pair.

        Args:
            t1: Input of the first temporal phase. Dimensions 2 onward define the size that the side outputs are
                resized to (presumably an NCHW tensor -- confirm against the caller).
            t2: Input of the second temporal phase.

        Returns:
            tuple: `(out5, out4, out3, out2, out1)`. `out1`..`out4` are the stage predictions resized to the size of
                `t1`; `out5` is the last stage's prediction and is returned without resizing.
        """
        # The encoder is switched to eval mode and executed without gradient tracking on every call.
        with paddle.no_grad():
            self.encoder1.eval()
            self.encoder2.eval()
            feats_a = self.encoder1(t1)
            feats_b = self.encoder2(t2)

        # Each encoder call must yield exactly five feature maps.
        a_l3, a_l8, a_l15, a_l22, a_l29 = feats_a
        b_l3, b_l8, b_l15, b_l22, b_l29 = feats_b

        # Stage 1: fuse the last picked features of both phases.
        fused = paddle.concat([a_l29, b_l29], axis=1)
        fused = self.o1_conv2(self.o1_conv1(fused))
        fused = self.bn_sa1(self.sa1(fused) * fused)
        out1 = self._upsample_like(self.o1_conv3(fused), t1)

        # Stage 2: upsample, attach the next pair of encoder features, then refine.
        fused = paddle.concat([self.trans_conv1(fused), a_l22, b_l22], axis=1)
        fused = self.ca2(fused) * fused
        fused = self.o2_conv3(self.o2_conv2(self.o2_conv1(fused)))
        fused = self.bn_sa2(self.sa2(fused) * fused)
        out2 = self._upsample_like(self.o2_conv4(fused), t1)

        # Stage 3.
        fused = paddle.concat([self.trans_conv2(fused), a_l15, b_l15], axis=1)
        fused = self.ca3(fused) * fused
        fused = self.o3_conv3(self.o3_conv2(self.o3_conv1(fused)))
        fused = self.bn_sa3(self.sa3(fused) * fused)
        out3 = self._upsample_like(self.o3_conv4(fused), t1)

        # Stage 4.
        fused = paddle.concat([self.trans_conv3(fused), a_l8, b_l8], axis=1)
        fused = self.ca4(fused) * fused
        fused = self.o4_conv3(self.o4_conv2(self.o4_conv1(fused)))
        fused = self.bn_sa4(self.sa4(fused) * fused)
        out4 = self._upsample_like(self.o4_conv4(fused), t1)

        # Stage 5: the final prediction is emitted as-is, without interpolation.
        fused = paddle.concat([self.trans_conv4(fused), a_l3, b_l3], axis=1)
        fused = self.ca5(fused) * fused
        fused = self.o5_conv3(self.o5_conv2(self.o5_conv1(fused)))
        fused = self.bn_sa5(self.sa5(fused) * fused)
        out5 = self.o5_conv4(fused)

        return out5, out4, out3, out2, out1
class VGG16FeaturePicker(nn.Layer):
    """
    Collect intermediate outputs of a pretrained VGG16 feature stack.

    Only the first 30 sublayers of `vgg16(pretrained=True).features` are kept, and they are put in eval mode at
    construction time.

    Args:
        indices (tuple[int], optional): Positions (within the kept sublayers) whose outputs are collected.
            Default: (3, 8, 15, 22, 29).
    """

    def __init__(self, indices=(3, 8, 15, 22, 29)):
        super().__init__()
        backbone = vgg16(pretrained=True)
        self.features = nn.LayerList(list(backbone.features)[:30])
        self.features.eval()
        # A set gives constant-time membership tests inside forward().
        self.indices = set(indices)

    def forward(self, x):
        """Feed `x` through every kept sublayer and return the outputs at the selected positions as a list."""
        collected = []
        for position, layer in enumerate(self.features):
            x = layer(x)
            if position not in self.indices:
                continue
            collected.append(x)
        return collected
def conv2d_bn(in_ch, out_ch, with_dropout=True):
    """
    Build a 3x3 convolution block: Conv2D -> PReLU -> norm layer, optionally followed by dropout.

    Args:
        in_ch (int): Number of input channels of the convolution.
        out_ch (int): Number of output channels of the convolution (also passed to `make_norm`).
        with_dropout (bool, optional): If true, append `nn.Dropout(p=0.6)` to the block. Default: True.

    Returns:
        nn.Sequential: The assembled block.
    """
    conv = nn.Conv2D(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
    activation = nn.PReLU()
    norm = make_norm(out_ch)
    if not with_dropout:
        return nn.Sequential(conv, activation, norm)
    return nn.Sequential(conv, activation, norm, nn.Dropout(p=0.6))

@ -31,7 +31,7 @@ from paddlers.utils.checkpoint import seg_pretrain_weights_dict
from paddlers.transforms import ImgDecoder, Resize from paddlers.transforms import ImgDecoder, Resize
import paddlers.models.cd as cd import paddlers.models.cd as cd
__all__ = ["CDNet", "UNetEarlyFusion", "UNetSiamConc", "UNetSiamDiff", "STANet", "BIT", "SNUNet"] __all__ = ["CDNet", "UNetEarlyFusion", "UNetSiamConc", "UNetSiamDiff", "STANet", "BIT", "SNUNet", "DSIFN"]
class BaseChangeDetector(BaseModel): class BaseChangeDetector(BaseModel):
@ -793,3 +793,28 @@ class SNUNet(BaseChangeDetector):
num_classes=num_classes, num_classes=num_classes,
use_mixed_loss=use_mixed_loss, use_mixed_loss=use_mixed_loss,
**params) **params)
class DSIFN(BaseChangeDetector):
    """
    Change detector task wrapper registered under the model name 'DSIFN'.

    Args:
        num_classes (int, optional): The number of target classes. Default: 2.
        use_mixed_loss (None, optional): Must be None; the loss configuration is then built automatically.
        use_dropout (bool, optional): Forwarded to the base class as the `use_dropout` keyword argument.
            Default: False.
        **params: Extra keyword arguments passed on to `BaseChangeDetector.__init__`.
    """

    def __init__(self,
                 num_classes=2,
                 use_mixed_loss=None,
                 use_dropout=False,
                 **params):
        # The dropout switch travels to the base class together with the other keyword arguments.
        params['use_dropout'] = use_dropout
        super(DSIFN, self).__init__(
            model_name='DSIFN',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            **params)
        # HACK: currently the only legal value of `use_mixed_loss` is None, in which case the loss specifications
        # are constructed automatically.
        assert use_mixed_loss is None
        if use_mixed_loss is None:
            # XXX: make sure the shallow copy works correctly here -- all five entries refer to the same loss
            # object. Presumably there is one entry per output of the DSIFN network; confirm against the trainer.
            shared_loss = paddleseg.models.CrossEntropyLoss()
            self.losses = {
                'types': [shared_loss] * 5,
                'coef': [1.0] * 5
            }
Loading…
Cancel
Save