commit
6bffd14166
18 changed files with 958 additions and 48 deletions
@ -0,0 +1,478 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
import paddle.nn.functional as F |
||||
|
||||
from .layers import BasicConv, MaxPool2x2, Conv1x1, Conv3x3 |
||||
|
||||
bn_mom = 1 - 0.0003 |
||||
|
||||
|
||||
class NLBlock(nn.Layer): |
||||
def __init__(self, in_channels): |
||||
super(NLBlock, self).__init__() |
||||
self.conv_v = BasicConv( |
||||
in_ch=in_channels, |
||||
out_ch=in_channels, |
||||
kernel_size=3, |
||||
norm=nn.BatchNorm2D( |
||||
in_channels, momentum=0.9)) |
||||
self.W = BasicConv( |
||||
in_ch=in_channels, |
||||
out_ch=in_channels, |
||||
kernel_size=3, |
||||
norm=nn.BatchNorm2D( |
||||
in_channels, momentum=0.9), |
||||
act=nn.ReLU()) |
||||
|
||||
def forward(self, x): |
||||
batch_size, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3] |
||||
value = self.conv_v(x) |
||||
value = value.reshape([batch_size, c, value.shape[2] * value.shape[3]]) |
||||
value = value.transpose([0, 2, 1]) # B * (H*W) * value_channels |
||||
key = x.reshape([batch_size, c, h * w]) # B * key_channels * (H*W) |
||||
query = x.reshape([batch_size, c, h * w]) |
||||
query = query.transpose([0, 2, 1]) |
||||
|
||||
sim_map = paddle.matmul(query, key) # B * (H*W) * (H*W) |
||||
sim_map = (c**-.5) * sim_map # B * (H*W) * (H*W) |
||||
sim_map = nn.functional.softmax(sim_map, axis=-1) # B * (H*W) * (H*W) |
||||
|
||||
context = paddle.matmul(sim_map, value) |
||||
context = context.transpose([0, 2, 1]) |
||||
context = context.reshape([batch_size, c, *x.shape[2:]]) |
||||
context = self.W(context) |
||||
|
||||
return context |
||||
|
||||
|
||||
class NLFPN(nn.Layer): |
||||
""" Non-local feature parymid network""" |
||||
|
||||
def __init__(self, in_dim, reduction=True): |
||||
super(NLFPN, self).__init__() |
||||
if reduction: |
||||
self.reduction = BasicConv( |
||||
in_ch=in_dim, |
||||
out_ch=in_dim // 4, |
||||
kernel_size=1, |
||||
norm=nn.BatchNorm2D( |
||||
in_dim // 4, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
self.re_reduction = BasicConv( |
||||
in_ch=in_dim // 4, |
||||
out_ch=in_dim, |
||||
kernel_size=1, |
||||
norm=nn.BatchNorm2D( |
||||
in_dim, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
in_dim = in_dim // 4 |
||||
else: |
||||
self.reduction = None |
||||
self.re_reduction = None |
||||
self.conv_e1 = BasicConv( |
||||
in_dim, |
||||
in_dim, |
||||
kernel_size=3, |
||||
norm=nn.BatchNorm2D( |
||||
in_dim, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
self.conv_e2 = BasicConv( |
||||
in_dim, |
||||
in_dim * 2, |
||||
kernel_size=3, |
||||
norm=nn.BatchNorm2D( |
||||
in_dim * 2, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
self.conv_e3 = BasicConv( |
||||
in_dim * 2, |
||||
in_dim * 4, |
||||
kernel_size=3, |
||||
norm=nn.BatchNorm2D( |
||||
in_dim * 4, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
self.conv_d1 = BasicConv( |
||||
in_dim, |
||||
in_dim, |
||||
kernel_size=3, |
||||
norm=nn.BatchNorm2D( |
||||
in_dim, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
self.conv_d2 = BasicConv( |
||||
in_dim * 2, |
||||
in_dim, |
||||
kernel_size=3, |
||||
norm=nn.BatchNorm2D( |
||||
in_dim, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
self.conv_d3 = BasicConv( |
||||
in_dim * 4, |
||||
in_dim * 2, |
||||
kernel_size=3, |
||||
norm=nn.BatchNorm2D( |
||||
in_dim * 2, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
self.nl3 = NLBlock(in_dim * 2) |
||||
self.nl2 = NLBlock(in_dim) |
||||
self.nl1 = NLBlock(in_dim) |
||||
|
||||
self.downsample_x2 = nn.MaxPool2D(stride=2, kernel_size=2) |
||||
self.upsample_x2 = nn.UpsamplingBilinear2D(scale_factor=2) |
||||
|
||||
def forward(self, x): |
||||
if self.reduction is not None: |
||||
x = self.reduction(x) |
||||
e1 = self.conv_e1(x) # C,H,W |
||||
e2 = self.conv_e2(self.downsample_x2(e1)) # 2C,H/2,W/2 |
||||
e3 = self.conv_e3(self.downsample_x2(e2)) # 4C,H/4,W/4 |
||||
|
||||
d3 = self.conv_d3(e3) # 2C,H/4,W/4 |
||||
nl = self.nl3(d3) |
||||
d3 = self.upsample_x2(paddle.multiply(d3, nl)) ##2C,H/2,W/2 |
||||
d2 = self.conv_d2(e2 + d3) # C,H/2,W/2 |
||||
nl = self.nl2(d2) |
||||
d2 = self.upsample_x2(paddle.multiply(d2, nl)) # C,H,W |
||||
d1 = self.conv_d1(e1 + d2) |
||||
nl = self.nl1(d1) |
||||
d1 = paddle.multiply(d1, nl) # C,H,W |
||||
if self.re_reduction is not None: |
||||
d1 = self.re_reduction(d1) |
||||
|
||||
return d1 |
||||
|
||||
|
||||
class Cat(nn.Layer): |
||||
def __init__(self, in_chn_high, in_chn_low, out_chn, upsample=False): |
||||
super(Cat, self).__init__() |
||||
self.do_upsample = upsample |
||||
self.upsample = nn.Upsample(scale_factor=2, mode="nearest") |
||||
self.conv2d = BasicConv( |
||||
in_chn_high + in_chn_low, |
||||
out_chn, |
||||
kernel_size=1, |
||||
norm=nn.BatchNorm2D( |
||||
out_chn, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
|
||||
def forward(self, x, y): |
||||
if self.do_upsample: |
||||
x = self.upsample(x) |
||||
|
||||
x = paddle.concat((x, y), 1) |
||||
|
||||
return self.conv2d(x) |
||||
|
||||
|
||||
class DoubleConv(nn.Layer): |
||||
def __init__(self, in_chn, out_chn, stride=1, dilation=1): |
||||
super(DoubleConv, self).__init__() |
||||
self.conv = nn.Sequential( |
||||
nn.Conv2D( |
||||
in_chn, |
||||
out_chn, |
||||
kernel_size=3, |
||||
stride=stride, |
||||
dilation=dilation, |
||||
padding=dilation), |
||||
nn.BatchNorm2D( |
||||
out_chn, momentum=bn_mom), |
||||
nn.ReLU(), |
||||
nn.Conv2D( |
||||
out_chn, out_chn, kernel_size=3, stride=1, padding=1), |
||||
nn.BatchNorm2D( |
||||
out_chn, momentum=bn_mom), |
||||
nn.ReLU()) |
||||
|
||||
def forward(self, x): |
||||
x = self.conv(x) |
||||
return x |
||||
|
||||
|
||||
class SEModule(nn.Layer): |
||||
def __init__(self, channels, reduction_channels): |
||||
super(SEModule, self).__init__() |
||||
self.fc1 = nn.Conv2D( |
||||
channels, |
||||
reduction_channels, |
||||
kernel_size=1, |
||||
padding=0, |
||||
bias_attr=True) |
||||
self.ReLU = nn.ReLU() |
||||
self.fc2 = nn.Conv2D( |
||||
reduction_channels, |
||||
channels, |
||||
kernel_size=1, |
||||
padding=0, |
||||
bias_attr=True) |
||||
|
||||
def forward(self, x): |
||||
x_se = x.reshape( |
||||
[x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).mean(-1).reshape( |
||||
[x.shape[0], x.shape[1], 1, 1]) |
||||
|
||||
x_se = self.fc1(x_se) |
||||
x_se = self.ReLU(x_se) |
||||
x_se = self.fc2(x_se) |
||||
return x * F.sigmoid(x_se) |
||||
|
||||
|
||||
class BasicBlock(nn.Layer): |
||||
expansion = 1 |
||||
|
||||
def __init__(self, |
||||
inplanes, |
||||
planes, |
||||
downsample=None, |
||||
use_se=False, |
||||
stride=1, |
||||
dilation=1): |
||||
super(BasicBlock, self).__init__() |
||||
first_planes = planes |
||||
outplanes = planes * self.expansion |
||||
|
||||
self.conv1 = DoubleConv(inplanes, first_planes) |
||||
self.conv2 = DoubleConv( |
||||
first_planes, outplanes, stride=stride, dilation=dilation) |
||||
self.se = SEModule(outplanes, planes // 4) if use_se else None |
||||
self.downsample = MaxPool2x2() if downsample else None |
||||
self.ReLU = nn.ReLU() |
||||
|
||||
def forward(self, x): |
||||
out = self.conv1(x) |
||||
residual = out |
||||
out = self.conv2(out) |
||||
|
||||
if self.se is not None: |
||||
out = self.se(out) |
||||
|
||||
if self.downsample is not None: |
||||
residual = self.downsample(residual) |
||||
|
||||
out = out + residual |
||||
out = self.ReLU(out) |
||||
return out |
||||
|
||||
|
||||
class DenseCatAdd(nn.Layer): |
||||
def __init__(self, in_chn, out_chn): |
||||
super(DenseCatAdd, self).__init__() |
||||
self.conv1 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) |
||||
self.conv2 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) |
||||
self.conv3 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) |
||||
self.conv_out = BasicConv( |
||||
in_chn, |
||||
out_chn, |
||||
kernel_size=1, |
||||
norm=nn.BatchNorm2D( |
||||
out_chn, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
|
||||
def forward(self, x, y): |
||||
x1 = self.conv1(x) |
||||
x2 = self.conv2(x1) |
||||
x3 = self.conv3(x2 + x1) |
||||
|
||||
y1 = self.conv1(y) |
||||
y2 = self.conv2(y1) |
||||
y3 = self.conv3(y2 + y1) |
||||
|
||||
return self.conv_out(x1 + x2 + x3 + y1 + y2 + y3) |
||||
|
||||
|
||||
class DenseCatDiff(nn.Layer): |
||||
def __init__(self, in_chn, out_chn): |
||||
super(DenseCatDiff, self).__init__() |
||||
self.conv1 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) |
||||
self.conv2 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) |
||||
self.conv3 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) |
||||
self.conv_out = BasicConv( |
||||
in_ch=in_chn, |
||||
out_ch=out_chn, |
||||
kernel_size=1, |
||||
norm=nn.BatchNorm2D( |
||||
out_chn, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
|
||||
def forward(self, x, y): |
||||
x1 = self.conv1(x) |
||||
x2 = self.conv2(x1) |
||||
x3 = self.conv3(x2 + x1) |
||||
|
||||
y1 = self.conv1(y) |
||||
y2 = self.conv2(y1) |
||||
y3 = self.conv3(y2 + y1) |
||||
out = self.conv_out(paddle.abs(x1 + x2 + x3 - y1 - y2 - y3)) |
||||
return out |
||||
|
||||
|
||||
class DFModule(nn.Layer): |
||||
"""Dense connection-based feature fusion module""" |
||||
|
||||
def __init__(self, dim_in, dim_out, reduction=True): |
||||
super(DFModule, self).__init__() |
||||
if reduction: |
||||
self.reduction = Conv1x1( |
||||
dim_in, |
||||
dim_in // 2, |
||||
norm=nn.BatchNorm2D( |
||||
dim_in // 2, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
dim_in = dim_in // 2 |
||||
else: |
||||
self.reduction = None |
||||
self.cat1 = DenseCatAdd(dim_in, dim_out) |
||||
self.cat2 = DenseCatDiff(dim_in, dim_out) |
||||
self.conv1 = Conv3x3( |
||||
dim_out, |
||||
dim_out, |
||||
norm=nn.BatchNorm2D( |
||||
dim_out, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
|
||||
def forward(self, x1, x2): |
||||
if self.reduction is not None: |
||||
x1 = self.reduction(x1) |
||||
x2 = self.reduction(x2) |
||||
x_add = self.cat1(x1, x2) |
||||
x_diff = self.cat2(x1, x2) |
||||
y = self.conv1(x_diff) + x_add |
||||
return y |
||||
|
||||
|
||||
class FCCDN(nn.Layer): |
||||
""" |
||||
The FCCDN implementation based on PaddlePaddle. |
||||
|
||||
The original article refers to |
||||
Pan Chen, et al., "FCCDN: Feature Constraint Network for VHR Image Change Detection" |
||||
(https://arxiv.org/pdf/2105.10860.pdf). |
||||
|
||||
Args: |
||||
in_channels (int): Number of input channels. Default: 3. |
||||
num_classes (int): Number of target classes. Default: 2. |
||||
os (int): Number of output stride. Default: 16. |
||||
use_se (bool): Whether to use SEModule. Default: True. |
||||
""" |
||||
|
||||
def __init__(self, in_channels=3, num_classes=2, os=16, use_se=True): |
||||
super(FCCDN, self).__init__() |
||||
if os >= 16: |
||||
dilation_list = [1, 1, 1, 1] |
||||
stride_list = [2, 2, 2, 2] |
||||
pool_list = [True, True, True, True] |
||||
elif os == 8: |
||||
dilation_list = [2, 1, 1, 1] |
||||
stride_list = [1, 2, 2, 2] |
||||
pool_list = [False, True, True, True] |
||||
else: |
||||
dilation_list = [2, 2, 1, 1] |
||||
stride_list = [1, 1, 2, 2] |
||||
pool_list = [False, False, True, True] |
||||
se_list = [use_se, use_se, use_se, use_se] |
||||
channel_list = [256, 128, 64, 32] |
||||
# Encoder |
||||
self.block1 = BasicBlock(in_channels, channel_list[3], pool_list[3], |
||||
se_list[3], stride_list[3], dilation_list[3]) |
||||
self.block2 = BasicBlock(channel_list[3], channel_list[2], pool_list[2], |
||||
se_list[2], stride_list[2], dilation_list[2]) |
||||
self.block3 = BasicBlock(channel_list[2], channel_list[1], pool_list[1], |
||||
se_list[1], stride_list[1], dilation_list[1]) |
||||
self.block4 = BasicBlock(channel_list[1], channel_list[0], pool_list[0], |
||||
se_list[0], stride_list[0], dilation_list[0]) |
||||
|
||||
# Center |
||||
self.center = NLFPN(channel_list[0], True) |
||||
|
||||
# Decoder |
||||
self.decoder3 = Cat(channel_list[0], |
||||
channel_list[1], |
||||
channel_list[1], |
||||
upsample=pool_list[0]) |
||||
self.decoder2 = Cat(channel_list[1], |
||||
channel_list[2], |
||||
channel_list[2], |
||||
upsample=pool_list[1]) |
||||
self.decoder1 = Cat(channel_list[2], |
||||
channel_list[3], |
||||
channel_list[3], |
||||
upsample=pool_list[2]) |
||||
|
||||
self.df1 = DFModule(channel_list[3], channel_list[3], True) |
||||
self.df2 = DFModule(channel_list[2], channel_list[2], True) |
||||
self.df3 = DFModule(channel_list[1], channel_list[1], True) |
||||
self.df4 = DFModule(channel_list[0], channel_list[0], True) |
||||
|
||||
self.catc3 = Cat(channel_list[0], |
||||
channel_list[1], |
||||
channel_list[1], |
||||
upsample=pool_list[0]) |
||||
self.catc2 = Cat(channel_list[1], |
||||
channel_list[2], |
||||
channel_list[2], |
||||
upsample=pool_list[1]) |
||||
self.catc1 = Cat(channel_list[2], |
||||
channel_list[3], |
||||
channel_list[3], |
||||
upsample=pool_list[2]) |
||||
|
||||
self.upsample_x2 = nn.Sequential( |
||||
nn.Conv2D( |
||||
channel_list[3], 8, kernel_size=3, stride=1, padding=1), |
||||
nn.BatchNorm2D( |
||||
8, momentum=bn_mom), |
||||
nn.ReLU(), |
||||
nn.UpsamplingBilinear2D(scale_factor=2)) |
||||
|
||||
self.conv_out = nn.Conv2D( |
||||
8, num_classes, kernel_size=3, stride=1, padding=1) |
||||
self.conv_out_class = nn.Conv2D( |
||||
channel_list[3], 1, kernel_size=1, stride=1, padding=0) |
||||
|
||||
def forward(self, t1, t2): |
||||
e1_1 = self.block1(t1) |
||||
e2_1 = self.block2(e1_1) |
||||
e3_1 = self.block3(e2_1) |
||||
y1 = self.block4(e3_1) |
||||
|
||||
e1_2 = self.block1(t2) |
||||
e2_2 = self.block2(e1_2) |
||||
e3_2 = self.block3(e2_2) |
||||
y2 = self.block4(e3_2) |
||||
|
||||
y1 = self.center(y1) |
||||
y2 = self.center(y2) |
||||
c = self.df4(y1, y2) |
||||
|
||||
y1 = self.decoder3(y1, e3_1) |
||||
y2 = self.decoder3(y2, e3_2) |
||||
c = self.catc3(c, self.df3(y1, y2)) |
||||
|
||||
y1 = self.decoder2(y1, e2_1) |
||||
y2 = self.decoder2(y2, e2_2) |
||||
c = self.catc2(c, self.df2(y1, y2)) |
||||
|
||||
y1 = self.decoder1(y1, e1_1) |
||||
y2 = self.decoder1(y2, e1_2) |
||||
|
||||
c = self.catc1(c, self.df1(y1, y2)) |
||||
y = self.conv_out(self.upsample_x2(c)) |
||||
|
||||
if self.training: |
||||
y1 = self.conv_out_class(y1) |
||||
y2 = self.conv_out_class(y2) |
||||
return [y, [y1, y2]] |
||||
else: |
||||
return [y] |
@ -0,0 +1,15 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
from .fccdn_loss import fccdn_ssl_loss |
@ -0,0 +1,170 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
import paddle.nn.functional as F |
||||
|
||||
|
||||
class DiceLoss(nn.Layer): |
||||
def __init__(self, batch=True): |
||||
super(DiceLoss, self).__init__() |
||||
self.batch = batch |
||||
|
||||
def soft_dice_coeff(self, y_pred, y_true): |
||||
smooth = 0.00001 |
||||
if self.batch: |
||||
i = paddle.sum(y_true) |
||||
j = paddle.sum(y_pred) |
||||
intersection = paddle.sum(y_true * y_pred) |
||||
else: |
||||
i = y_true.sum(1).sum(1).sum(1) |
||||
j = y_pred.sum(1).sum(1).sum(1) |
||||
intersection = (y_true * y_pred).sum(1).sum(1).sum(1) |
||||
score = (2. * intersection + smooth) / (i + j + smooth) |
||||
return score.mean() |
||||
|
||||
def soft_dice_loss(self, y_pred, y_true): |
||||
loss = 1 - self.soft_dice_coeff(y_pred, y_true) |
||||
return loss |
||||
|
||||
def forward(self, y_pred, y_true): |
||||
return self.soft_dice_loss(y_pred.astype(paddle.float32), y_true) |
||||
|
||||
|
||||
class MultiClassDiceLoss(nn.Layer): |
||||
def __init__( |
||||
self, |
||||
weight, |
||||
batch=True, |
||||
ignore_index=-1, |
||||
do_softmax=False, |
||||
**kwargs, ): |
||||
super(MultiClassDiceLoss, self).__init__() |
||||
self.ignore_index = ignore_index |
||||
self.weight = weight |
||||
self.do_softmax = do_softmax |
||||
self.binary_diceloss = DiceLoss(batch) |
||||
|
||||
def forward(self, y_pred, y_true): |
||||
if self.do_softmax: |
||||
y_pred = paddle.nn.functional.softmax(y_pred, axis=1) |
||||
y_true = F.one_hot(y_true.long(), y_pred.shape[1]).permute(0, 3, 1, 2) |
||||
total_loss = 0.0 |
||||
tmp_i = 0.0 |
||||
for i in range(y_pred.shape[1]): |
||||
if i != self.ignore_index: |
||||
diceloss = self.binary_diceloss(y_pred[:, i, :, :], |
||||
y_true[:, i, :, :]) |
||||
total_loss += paddle.multiply(diceloss, self.weight[i]) |
||||
tmp_i += 1.0 |
||||
return total_loss / tmp_i |
||||
|
||||
|
||||
class DiceBCELoss(nn.Layer): |
||||
"""Binary change detection task loss""" |
||||
|
||||
def __init__(self): |
||||
super(DiceBCELoss, self).__init__() |
||||
self.bce_loss = nn.BCELoss() |
||||
self.binnary_dice = DiceLoss() |
||||
|
||||
def forward(self, scores, labels, do_sigmoid=True): |
||||
if len(scores.shape) > 3: |
||||
scores = scores.squeeze(1) |
||||
if len(labels.shape) > 3: |
||||
labels = labels.squeeze(1) |
||||
if do_sigmoid: |
||||
scores = paddle.nn.functional.sigmoid(scores.clone()) |
||||
diceloss = self.binnary_dice(scores, labels) |
||||
bceloss = self.bce_loss(scores, labels) |
||||
return diceloss + bceloss |
||||
|
||||
|
||||
class McDiceBCELoss(nn.Layer): |
||||
"""Multi-class change detection task loss""" |
||||
|
||||
def __init__(self, weight, do_sigmoid=True): |
||||
super(McDiceBCELoss, self).__init__() |
||||
self.ce_loss = nn.CrossEntropyLoss(weight) |
||||
self.dice = MultiClassDiceLoss(weight, do_sigmoid) |
||||
|
||||
def forward(self, scores, labels): |
||||
if len(scores.shape) < 4: |
||||
scores = scores.unsqueeze(1) |
||||
if len(labels.shape) < 4: |
||||
labels = labels.unsqueeze(1) |
||||
diceloss = self.dice(scores, labels) |
||||
bceloss = self.ce_loss(scores, labels) |
||||
return diceloss + bceloss |
||||
|
||||
|
||||
def fccdn_ssl_loss(logits_list, labels): |
||||
""" |
||||
Self-supervised learning loss for change detection. |
||||
|
||||
The original article refers to |
||||
Pan Chen, et al., "FCCDN: Feature Constraint Network for VHR Image Change Detection" |
||||
(https://arxiv.org/pdf/2105.10860.pdf). |
||||
|
||||
Args: |
||||
logits_list (list[paddle.Tensor]): Single-channel segmentation logit maps for each of the two temporal phases. |
||||
labels (paddle.Tensor): Binary change labels. |
||||
""" |
||||
|
||||
# Create loss |
||||
criterion_ssl = DiceBCELoss() |
||||
|
||||
# Get downsampled change map |
||||
h, w = logits_list[0].shape[-2], logits_list[0].shape[-1] |
||||
labels_downsample = F.interpolate(x=labels.unsqueeze(1), size=[h, w]) |
||||
labels_type = str(labels_downsample.dtype) |
||||
assert "int" in labels_type or "bool" in labels_type,\ |
||||
f"Expected dtype of labels to be int or bool, but got {labels_type}" |
||||
|
||||
# Seg map |
||||
out1 = paddle.nn.functional.sigmoid(logits_list[0]).clone() |
||||
out2 = paddle.nn.functional.sigmoid(logits_list[1]).clone() |
||||
out3 = out1.clone() |
||||
out4 = out2.clone() |
||||
|
||||
out1 = paddle.where(labels_downsample == 1, paddle.zeros_like(out1), out1) |
||||
out2 = paddle.where(labels_downsample == 1, paddle.zeros_like(out2), out2) |
||||
out3 = paddle.where(labels_downsample != 1, paddle.zeros_like(out3), out3) |
||||
out4 = paddle.where(labels_downsample != 1, paddle.zeros_like(out4), out4) |
||||
|
||||
pred_seg_pre_tmp1 = paddle.where(out1 <= 0.5, |
||||
paddle.zeros_like(out1), |
||||
paddle.ones_like(out1)) |
||||
pred_seg_post_tmp1 = paddle.where(out2 <= 0.5, |
||||
paddle.zeros_like(out2), |
||||
paddle.ones_like(out2)) |
||||
|
||||
pred_seg_pre_tmp2 = paddle.where(out3 <= 0.5, |
||||
paddle.zeros_like(out3), |
||||
paddle.ones_like(out3)) |
||||
pred_seg_post_tmp2 = paddle.where(out4 <= 0.5, |
||||
paddle.zeros_like(out4), |
||||
paddle.ones_like(out4)) |
||||
|
||||
# Seg loss |
||||
labels_downsample = labels_downsample.astype(paddle.float32) |
||||
loss_aux = 0.2 * criterion_ssl(out1, pred_seg_post_tmp1, False) |
||||
loss_aux += 0.2 * criterion_ssl(out2, pred_seg_pre_tmp1, False) |
||||
loss_aux += 0.2 * criterion_ssl( |
||||
out3, labels_downsample - pred_seg_post_tmp2, False) |
||||
loss_aux += 0.2 * criterion_ssl(out4, labels_downsample - pred_seg_pre_tmp2, |
||||
False) |
||||
|
||||
return loss_aux |
@ -0,0 +1,13 @@ |
||||
# Basic configurations of FCCDN |
||||
|
||||
_base_: ../_base_/airchange.yaml |
||||
|
||||
save_dir: ./test_tipc/output/cd/fccdn/ |
||||
|
||||
model: !Node |
||||
type: FCCDN |
||||
|
||||
learning_rate: 0.07 |
||||
lr_decay_power: 0.6 |
||||
log_interval_steps: 100 |
||||
save_interval_epochs: 3 |
@ -0,0 +1,53 @@ |
||||
===========================train_params=========================== |
||||
model_name:cd:fccdn |
||||
python:python |
||||
gpu_list:0 |
||||
use_gpu:null|null |
||||
--precision:null |
||||
--num_epochs:lite_train_lite_infer=15|lite_train_whole_infer=15|whole_train_whole_infer=15 |
||||
--save_dir:adaptive |
||||
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4 |
||||
--model_path:null |
||||
train_model_name:best_model |
||||
train_infer_file_list:./test_tipc/data/airchange/:./test_tipc/data/airchange/eval.txt |
||||
null:null |
||||
## |
||||
trainer:norm |
||||
norm_train:test_tipc/run_task.py train cd --config ./test_tipc/configs/cd/fccdn/fccdn.yaml |
||||
pact_train:null |
||||
fpgm_train:null |
||||
distill_train:null |
||||
null:null |
||||
null:null |
||||
## |
||||
===========================eval_params=========================== |
||||
eval:null |
||||
null:null |
||||
## |
||||
===========================export_params=========================== |
||||
--save_dir:adaptive |
||||
--model_dir:adaptive |
||||
--fixed_input_shape:[1,3,256,256] |
||||
norm_export:deploy/export/export_model.py |
||||
quant_export:null |
||||
fpgm_export:null |
||||
distill_export:null |
||||
export1:null |
||||
export2:null |
||||
===========================infer_params=========================== |
||||
infer_model:null |
||||
infer_export:null |
||||
infer_quant:False |
||||
inference:test_tipc/infer.py |
||||
--device:cpu|gpu |
||||
--enable_mkldnn:True |
||||
--cpu_threads:6 |
||||
--batch_size:1 |
||||
--use_trt:False |
||||
--precision:fp32 |
||||
--model_dir:null |
||||
--file_list:null:null |
||||
--save_log_path:null |
||||
--benchmark:True |
||||
--model_name:fccdn |
||||
null:null |
@ -0,0 +1,41 @@ |
||||
#!/bin bash |
||||
|
||||
rm -rf /usr/local/python2.7.15/bin/python |
||||
rm -rf /usr/local/python2.7.15/bin/pip |
||||
ln -s /usr/local/bin/python3.7 /usr/local/python2.7.15/bin/python |
||||
ln -s /usr/local/bin/pip3.7 /usr/local/python2.7.15/bin/pip |
||||
export PYTHONPATH=`pwd` |
||||
|
||||
python -m pip install --upgrade pip --ignore-installed |
||||
# python -m pip install --upgrade numpy --ignore-installed |
||||
python -m pip uninstall paddlepaddle-gpu -y |
||||
if [[ ${branch} == 'develop' ]];then |
||||
echo "checkout develop !" |
||||
python -m pip install ${paddle_dev} --no-cache-dir |
||||
else |
||||
echo "checkout release !" |
||||
python -m pip install ${paddle_release} --no-cache-dir |
||||
fi |
||||
|
||||
echo -e '*****************paddle_version*****' |
||||
python -c 'import paddle;print(paddle.version.commit)' |
||||
echo -e '*****************paddleseg_version****' |
||||
git rev-parse HEAD |
||||
|
||||
pip install -r requirements.txt --ignore-installed |
||||
pip install -e . |
||||
pip install https://versaweb.dl.sourceforge.net/project/gdal-wheels-for-linux/GDAL-3.4.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl |
||||
|
||||
git clone https://github.com/LDOUBLEV/AutoLog |
||||
cd AutoLog |
||||
pip install -r requirements.txt |
||||
python setup.py bdist_wheel |
||||
pip install ./dist/auto_log*.whl |
||||
cd .. |
||||
|
||||
unset http_proxy https_proxy |
||||
|
||||
set -e |
||||
|
||||
cd tests/ |
||||
bash run_fast_tests.sh |
@ -0,0 +1,13 @@ |
||||
#!/usr/bin/env bash |
||||
|
||||
cd .. |
||||
|
||||
for config in $(ls test_tipc/configs/*/*/train_infer_python.txt); do |
||||
bash test_tipc/prepare.sh ${config} lite_train_lite_infer |
||||
bash test_tipc/test_train_inference_python.sh ${config} lite_train_lite_infer |
||||
task="$(basename $(dirname $(dirname ${config})))" |
||||
model="$(basename $(dirname ${config}))" |
||||
if grep -q 'failed' "test_tipc/output/${task}/${model}/lite_train_lite_infer/results_python.log"; then |
||||
exit 1 |
||||
fi |
||||
done |
@ -0,0 +1,94 @@ |
||||
#!/usr/bin/env python |
||||
|
||||
# 变化检测模型FCCDN训练示例脚本 |
||||
# 执行此脚本前,请确认已正确安装PaddleRS库 |
||||
|
||||
import paddlers as pdrs |
||||
from paddlers import transforms as T |
||||
|
||||
# 数据集存放目录 |
||||
DATA_DIR = './data/airchange/' |
||||
# 训练集`file_list`文件路径 |
||||
TRAIN_FILE_LIST_PATH = './data/airchange/train.txt' |
||||
# 验证集`file_list`文件路径 |
||||
EVAL_FILE_LIST_PATH = './data/airchange/eval.txt' |
||||
# 实验目录,保存输出的模型权重和结果 |
||||
EXP_DIR = './output/fccdn/' |
||||
|
||||
# 下载和解压AirChange数据集 |
||||
pdrs.utils.download_and_decompress( |
||||
'https://paddlers.bj.bcebos.com/datasets/airchange.zip', path='./data/') |
||||
|
||||
# 定义训练和验证时使用的数据变换(数据增强、预处理等) |
||||
# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行 |
||||
# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md |
||||
train_transforms = T.Compose([ |
||||
# 读取影像 |
||||
T.DecodeImg(), |
||||
# 随机裁剪 |
||||
T.RandomCrop( |
||||
# 裁剪区域将被缩放到256x256 |
||||
crop_size=256, |
||||
# 裁剪区域的横纵比在0.5-2之间变动 |
||||
aspect_ratio=[0.5, 2.0], |
||||
# 裁剪区域相对原始影像长宽比例在一定范围内变动,最小不低于原始长宽的1/5 |
||||
scaling=[0.2, 1.0]), |
||||
# 以50%的概率实施随机水平翻转 |
||||
T.RandomHorizontalFlip(prob=0.5), |
||||
# 将数据归一化到[-1,1] |
||||
T.Normalize( |
||||
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), |
||||
T.ArrangeChangeDetector('train') |
||||
]) |
||||
|
||||
eval_transforms = T.Compose([ |
||||
T.DecodeImg(), |
||||
# 验证阶段与训练阶段的数据归一化方式必须相同 |
||||
T.Normalize( |
||||
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), |
||||
T.ReloadMask(), |
||||
T.ArrangeChangeDetector('eval') |
||||
]) |
||||
|
||||
# 分别构建训练和验证所用的数据集 |
||||
train_dataset = pdrs.datasets.CDDataset( |
||||
data_dir=DATA_DIR, |
||||
file_list=TRAIN_FILE_LIST_PATH, |
||||
label_list=None, |
||||
transforms=train_transforms, |
||||
num_workers=0, |
||||
shuffle=True, |
||||
with_seg_labels=False, |
||||
binarize_labels=True) |
||||
|
||||
eval_dataset = pdrs.datasets.CDDataset( |
||||
data_dir=DATA_DIR, |
||||
file_list=EVAL_FILE_LIST_PATH, |
||||
label_list=None, |
||||
transforms=eval_transforms, |
||||
num_workers=0, |
||||
shuffle=False, |
||||
with_seg_labels=False, |
||||
binarize_labels=True) |
||||
|
||||
# 使用默认参数构建FCCDN模型 |
||||
# 目前已支持的模型及模型输入参数请参考: |
||||
# https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/change_detector.py |
||||
model = pdrs.tasks.cd.FCCDN() |
||||
|
||||
# 执行模型训练 |
||||
model.train( |
||||
num_epochs=5, |
||||
train_dataset=train_dataset, |
||||
train_batch_size=4, |
||||
eval_dataset=eval_dataset, |
||||
save_interval_epochs=2, |
||||
# 每多少次迭代记录一次日志 |
||||
log_interval_steps=50, |
||||
save_dir=EXP_DIR, |
||||
# 是否使用early stopping策略,当精度不再改善时提前终止训练 |
||||
early_stop=False, |
||||
# 是否启用VisualDL日志功能 |
||||
use_vdl=True, |
||||
# 指定从某个检查点继续训练 |
||||
resume_checkpoint=None) |
Loading…
Reference in new issue