[论文复现赛] FCCDN (#23)

精度验收通过,代码符合规范,论文复现成功。

Co-authored-by: liuxtakeoff <763848861.qq.com>
own
liuxtakeoff 2 years ago committed by GitHub
parent cc0788fc57
commit bbbbd3c7c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      paddlers/rs_models/cd/__init__.py
  2. 478
      paddlers/rs_models/cd/fccdn.py
  3. 15
      paddlers/rs_models/cd/losses/__init__.py
  4. 170
      paddlers/rs_models/cd/losses/fccdn_loss.py
  5. 32
      paddlers/tasks/change_detector.py
  6. 13
      test_tipc/configs/cd/fccdn/fccdn.yaml
  7. 53
      test_tipc/configs/cd/fccdn/train_infer_python.txt
  8. 94
      tutorials/train/change_detection/fccdn.py

@ -23,3 +23,5 @@ from .fc_ef import FCEarlyFusion
from .fc_siam_conc import FCSiamConc
from .fc_siam_diff import FCSiamDiff
from .changeformer import ChangeFormer
from .fccdn import FCCDN
from .losses import fccdn_ssl_loss

@ -0,0 +1,478 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .layers import BasicConv, MaxPool2x2, Conv1x1, Conv3x3
# Momentum value passed to most of the nn.BatchNorm2D layers defined in this
# module (NLBlock is the exception: it hard-codes 0.9).
bn_mom = 1 - 0.0003
class NLBlock(nn.Layer):
    """
    Non-local block that re-weights spatial positions by pairwise similarity.

    The input itself serves as both query and key. Their product is scaled by
    ``c ** -0.5``, normalized with softmax over the last axis and applied to a
    convolved "value" projection of the input. The aggregated result is passed
    through a final conv + BN + ReLU.

    Args:
        in_channels (int): Number of channels of the input feature map. The
            output has the same number of channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Value projection: no `act` argument is passed here.
        value_norm = nn.BatchNorm2D(in_channels, momentum=0.9)
        self.conv_v = BasicConv(
            in_ch=in_channels,
            out_ch=in_channels,
            kernel_size=3,
            norm=value_norm)
        # Output projection applied to the aggregated context.
        out_norm = nn.BatchNorm2D(in_channels, momentum=0.9)
        self.W = BasicConv(
            in_ch=in_channels,
            out_ch=in_channels,
            kernel_size=3,
            norm=out_norm,
            act=nn.ReLU())

    def forward(self, x):
        n, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]

        v = self.conv_v(x)
        v = paddle.reshape(v, [n, c, v.shape[2] * v.shape[3]])
        v = paddle.transpose(v, [0, 2, 1])  # B * (H*W) * C

        k = paddle.reshape(x, [n, c, h * w])  # B * C * (H*W)
        q = paddle.reshape(x, [n, c, h * w])
        q = paddle.transpose(q, [0, 2, 1])  # B * (H*W) * C

        attn = paddle.matmul(q, k)  # B * (H*W) * (H*W)
        attn = (c**-.5) * attn
        attn = F.softmax(attn, axis=-1)

        out = paddle.matmul(attn, v)  # B * (H*W) * C
        out = paddle.transpose(out, [0, 2, 1])
        out = paddle.reshape(out, [n, c, *x.shape[2:]])
        return self.W(out)
class NLFPN(nn.Layer):
    """
    Non-local feature pyramid network.

    A three-level encoder/decoder in which each decoder output is multiplied
    element-wise by the response of an NLBlock computed from it.

    Args:
        in_dim (int): Number of input channels.
        reduction (bool): If True, a 1x1 conv first reduces the channels to
            ``in_dim // 4`` and another 1x1 conv restores ``in_dim`` channels
            at the end. Default: True.
    """

    def __init__(self, in_dim, reduction=True):
        super().__init__()
        make = self._conv_bn_relu

        if reduction:
            squeezed = in_dim // 4
            self.reduction = make(in_dim, squeezed, 1)
            self.re_reduction = make(squeezed, in_dim, 1)
            in_dim = squeezed
        else:
            self.reduction = None
            self.re_reduction = None

        ch = in_dim
        # Encoder convs.
        self.conv_e1 = make(ch, ch, 3)
        self.conv_e2 = make(ch, ch * 2, 3)
        self.conv_e3 = make(ch * 2, ch * 4, 3)
        # Decoder convs.
        self.conv_d1 = make(ch, ch, 3)
        self.conv_d2 = make(ch * 2, ch, 3)
        self.conv_d3 = make(ch * 4, ch * 2, 3)
        # One non-local block per decoder level.
        self.nl3 = NLBlock(ch * 2)
        self.nl2 = NLBlock(ch)
        self.nl1 = NLBlock(ch)

        self.downsample_x2 = nn.MaxPool2D(stride=2, kernel_size=2)
        self.upsample_x2 = nn.UpsamplingBilinear2D(scale_factor=2)

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, kernel_size):
        """Build a BasicConv followed by BatchNorm (momentum=bn_mom) and ReLU."""
        return BasicConv(
            in_ch=in_ch,
            out_ch=out_ch,
            kernel_size=kernel_size,
            norm=nn.BatchNorm2D(
                out_ch, momentum=bn_mom),
            act=nn.ReLU())

    def forward(self, x):
        if self.reduction is not None:
            x = self.reduction(x)

        # Encoder path.
        enc1 = self.conv_e1(x)  # C,H,W
        enc2 = self.conv_e2(self.downsample_x2(enc1))  # 2C,H/2,W/2
        enc3 = self.conv_e3(self.downsample_x2(enc2))  # 4C,H/4,W/4

        # Decoder path: gate each level with its non-local response.
        dec3 = self.conv_d3(enc3)  # 2C,H/4,W/4
        gate = self.nl3(dec3)
        dec3 = self.upsample_x2(paddle.multiply(dec3, gate))  # 2C,H/2,W/2

        dec2 = self.conv_d2(enc2 + dec3)  # C,H/2,W/2
        gate = self.nl2(dec2)
        dec2 = self.upsample_x2(paddle.multiply(dec2, gate))  # C,H,W

        dec1 = self.conv_d1(enc1 + dec2)
        gate = self.nl1(dec1)
        dec1 = paddle.multiply(dec1, gate)  # C,H,W

        if self.re_reduction is not None:
            dec1 = self.re_reduction(dec1)
        return dec1
class Cat(nn.Layer):
    """
    Concatenate two feature maps along the channel axis and fuse them with a
    1x1 conv + BN + ReLU.

    Args:
        in_chn_high (int): Channels of the first input (`x` in `forward`).
        in_chn_low (int): Channels of the second input (`y` in `forward`).
        out_chn (int): Number of output channels.
        upsample (bool): If True, `x` is upsampled by a factor of 2 (nearest
            neighbour) before concatenation. Default: False.
    """

    def __init__(self, in_chn_high, in_chn_low, out_chn, upsample=False):
        super().__init__()
        self.do_upsample = upsample
        # Always created; only applied when `do_upsample` is set.
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        fused_in = in_chn_high + in_chn_low
        self.conv2d = BasicConv(
            fused_in,
            out_chn,
            kernel_size=1,
            norm=nn.BatchNorm2D(
                out_chn, momentum=bn_mom),
            act=nn.ReLU())

    def forward(self, x, y):
        high = self.upsample(x) if self.do_upsample else x
        merged = paddle.concat((high, y), 1)
        return self.conv2d(merged)
class DoubleConv(nn.Layer):
    """
    Two stacked 3x3 conv + BN + ReLU units.

    Only the first conv uses the given `stride` and `dilation`; its padding
    equals `dilation`. The second conv always uses stride 1 and padding 1.

    Args:
        in_chn (int): Number of input channels.
        out_chn (int): Number of output channels of both convs.
        stride (int): Stride of the first conv. Default: 1.
        dilation (int): Dilation of the first conv. Default: 1.
    """

    def __init__(self, in_chn, out_chn, stride=1, dilation=1):
        super().__init__()
        first_unit = [
            nn.Conv2D(
                in_chn,
                out_chn,
                kernel_size=3,
                stride=stride,
                dilation=dilation,
                padding=dilation),
            nn.BatchNorm2D(
                out_chn, momentum=bn_mom),
            nn.ReLU(),
        ]
        second_unit = [
            nn.Conv2D(
                out_chn, out_chn, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2D(
                out_chn, momentum=bn_mom),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*first_unit, *second_unit)

    def forward(self, x):
        return self.conv(x)
class SEModule(nn.Layer):
    """
    Squeeze-and-excitation style channel gating.

    The input is averaged over its spatial positions, passed through two 1x1
    convs with a ReLU in between, and the sigmoid of the result scales the
    input.

    Args:
        channels (int): Number of input (and output) channels.
        reduction_channels (int): Number of channels between the two convs.
    """

    def __init__(self, channels, reduction_channels):
        super().__init__()
        self.fc1 = nn.Conv2D(
            channels,
            reduction_channels,
            kernel_size=1,
            padding=0,
            bias_attr=True)
        self.ReLU = nn.ReLU()
        self.fc2 = nn.Conv2D(
            reduction_channels,
            channels,
            kernel_size=1,
            padding=0,
            bias_attr=True)

    def forward(self, x):
        n, c = x.shape[0], x.shape[1]
        # Spatial average, computed on the flattened H*W axis.
        flat = x.reshape([n, c, x.shape[2] * x.shape[3]])
        squeezed = flat.mean(-1).reshape([n, c, 1, 1])
        gate = self.fc2(self.ReLU(self.fc1(squeezed)))
        return x * F.sigmoid(gate)
class BasicBlock(nn.Layer):
    """
    Encoder block: two DoubleConv stages with a residual connection taken
    after the first stage.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of channels produced by both stages.
        downsample (bool|None): If truthy, the residual is passed through
            MaxPool2x2 before being added. Default: None.
        use_se (bool): Whether to apply an SEModule to the second stage's
            output. Default: False.
        stride (int): Stride forwarded to the second DoubleConv. Default: 1.
        dilation (int): Dilation forwarded to the second DoubleConv.
            Default: 1.
    """

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 downsample=None,
                 use_se=False,
                 stride=1,
                 dilation=1):
        super().__init__()
        mid_planes = planes
        out_planes = planes * self.expansion

        self.conv1 = DoubleConv(inplanes, mid_planes)
        self.conv2 = DoubleConv(
            mid_planes, out_planes, stride=stride, dilation=dilation)

        if use_se:
            self.se = SEModule(out_planes, planes // 4)
        else:
            self.se = None

        if downsample:
            self.downsample = MaxPool2x2()
        else:
            self.downsample = None

        self.ReLU = nn.ReLU()

    def forward(self, x):
        shortcut = self.conv1(x)
        feat = self.conv2(shortcut)

        if self.se is not None:
            feat = self.se(feat)
        if self.downsample is not None:
            shortcut = self.downsample(shortcut)

        return self.ReLU(feat + shortcut)
class DenseCatAdd(nn.Layer):
    """
    Densely connected fusion of two inputs by addition.

    Both inputs go through the same three 3x3 convs (shared weights); all six
    intermediate features are summed and projected by a 1x1 conv + BN + ReLU.

    Args:
        in_chn (int): Number of channels of each input.
        out_chn (int): Number of output channels.
    """

    def __init__(self, in_chn, out_chn):
        super().__init__()
        self.conv1 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
        self.conv2 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
        self.conv3 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
        self.conv_out = BasicConv(
            in_chn,
            out_chn,
            kernel_size=1,
            norm=nn.BatchNorm2D(
                out_chn, momentum=bn_mom),
            act=nn.ReLU())

    def _dense_features(self, t):
        """Return the three chained conv features for one input."""
        f1 = self.conv1(t)
        f2 = self.conv2(f1)
        f3 = self.conv3(f2 + f1)
        return f1, f2, f3

    def forward(self, x, y):
        x1, x2, x3 = self._dense_features(x)
        y1, y2, y3 = self._dense_features(y)
        # Left-to-right accumulation, first the x features, then the y ones.
        fused = x1 + x2 + x3 + y1 + y2 + y3
        return self.conv_out(fused)
class DenseCatDiff(nn.Layer):
    """
    Densely connected fusion of two inputs by absolute difference.

    Both inputs go through the same three 3x3 convs (shared weights); the
    absolute difference between the summed features of the two inputs is
    projected by a 1x1 conv + BN + ReLU.

    Args:
        in_chn (int): Number of channels of each input.
        out_chn (int): Number of output channels.
    """

    def __init__(self, in_chn, out_chn):
        super().__init__()
        self.conv1 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
        self.conv2 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
        self.conv3 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
        self.conv_out = BasicConv(
            in_ch=in_chn,
            out_ch=out_chn,
            kernel_size=1,
            norm=nn.BatchNorm2D(
                out_chn, momentum=bn_mom),
            act=nn.ReLU())

    def _dense_features(self, t):
        """Return the three chained conv features for one input."""
        f1 = self.conv1(t)
        f2 = self.conv2(f1)
        f3 = self.conv3(f2 + f1)
        return f1, f2, f3

    def forward(self, x, y):
        x1, x2, x3 = self._dense_features(x)
        y1, y2, y3 = self._dense_features(y)
        # Left-to-right accumulation: add the x features, subtract the y ones.
        delta = x1 + x2 + x3 - y1 - y2 - y3
        return self.conv_out(paddle.abs(delta))
class DFModule(nn.Layer):
    """
    Dense connection-based feature fusion module.

    Combines an additive fusion (DenseCatAdd) and a difference fusion
    (DenseCatDiff) of two feature maps; the difference branch goes through an
    extra 3x3 conv before the two are summed.

    Args:
        dim_in (int): Number of channels of each input.
        dim_out (int): Number of output channels.
        reduction (bool): If True, both inputs are first reduced to
            ``dim_in // 2`` channels by a shared 1x1 conv. Default: True.
    """

    def __init__(self, dim_in, dim_out, reduction=True):
        super().__init__()
        if reduction:
            half = dim_in // 2
            self.reduction = Conv1x1(
                dim_in,
                half,
                norm=nn.BatchNorm2D(
                    half, momentum=bn_mom),
                act=nn.ReLU())
            dim_in = half
        else:
            self.reduction = None

        self.cat1 = DenseCatAdd(dim_in, dim_out)
        self.cat2 = DenseCatDiff(dim_in, dim_out)
        self.conv1 = Conv3x3(
            dim_out,
            dim_out,
            norm=nn.BatchNorm2D(
                dim_out, momentum=bn_mom),
            act=nn.ReLU())

    def forward(self, x1, x2):
        if self.reduction is not None:
            x1 = self.reduction(x1)
            x2 = self.reduction(x2)

        added = self.cat1(x1, x2)
        diffed = self.cat2(x1, x2)
        return self.conv1(diffed) + added
class FCCDN(nn.Layer):
    """
    The FCCDN implementation based on PaddlePaddle.

    The original article refers to
        Pan Chen, et al., "FCCDN: Feature Constraint Network for VHR Image Change Detection"
        (https://arxiv.org/pdf/2105.10860.pdf).

    Args:
        in_channels (int): Number of input channels. Default: 3.
        num_classes (int): Number of target classes. Default: 2.
        os (int): Output stride. Values >= 16 downsample in all four encoder
            blocks, 8 disables downsampling in the last block, and any other
            value disables it in the last two blocks. Default: 16.
        use_se (bool): Whether to use SEModule. Default: True.

    Returns (from `forward`):
        In training mode, ``[y, [y1, y2]]`` where `y` is the change map and
        `y1`/`y2` are single-channel maps computed from the two temporal
        branches. Otherwise ``[y]``.
    """

    def __init__(self, in_channels=3, num_classes=2, os=16, use_se=True):
        super().__init__()

        # Per-block settings, indexed from the deepest block (index 0) to the
        # shallowest one (index 3).
        if os >= 16:
            dilations = [1, 1, 1, 1]
            strides = [2, 2, 2, 2]
            pools = [True, True, True, True]
        elif os == 8:
            dilations = [2, 1, 1, 1]
            strides = [1, 2, 2, 2]
            pools = [False, True, True, True]
        else:
            dilations = [2, 2, 1, 1]
            strides = [1, 1, 2, 2]
            pools = [False, False, True, True]
        chs = [256, 128, 64, 32]

        # Encoder
        self.block1 = BasicBlock(
            in_channels,
            chs[3],
            downsample=pools[3],
            use_se=use_se,
            stride=strides[3],
            dilation=dilations[3])
        self.block2 = BasicBlock(
            chs[3],
            chs[2],
            downsample=pools[2],
            use_se=use_se,
            stride=strides[2],
            dilation=dilations[2])
        self.block3 = BasicBlock(
            chs[2],
            chs[1],
            downsample=pools[1],
            use_se=use_se,
            stride=strides[1],
            dilation=dilations[1])
        self.block4 = BasicBlock(
            chs[1],
            chs[0],
            downsample=pools[0],
            use_se=use_se,
            stride=strides[0],
            dilation=dilations[0])

        # Center
        self.center = NLFPN(chs[0], True)

        # Decoder (shared by both temporal branches)
        self.decoder3 = Cat(chs[0], chs[1], chs[1], upsample=pools[0])
        self.decoder2 = Cat(chs[1], chs[2], chs[2], upsample=pools[1])
        self.decoder1 = Cat(chs[2], chs[3], chs[3], upsample=pools[2])

        # Bi-temporal feature fusion at every scale
        self.df1 = DFModule(chs[3], chs[3], True)
        self.df2 = DFModule(chs[2], chs[2], True)
        self.df3 = DFModule(chs[1], chs[1], True)
        self.df4 = DFModule(chs[0], chs[0], True)

        # Change-branch decoder
        self.catc3 = Cat(chs[0], chs[1], chs[1], upsample=pools[0])
        self.catc2 = Cat(chs[1], chs[2], chs[2], upsample=pools[1])
        self.catc1 = Cat(chs[2], chs[3], chs[3], upsample=pools[2])

        # Heads
        self.upsample_x2 = nn.Sequential(
            nn.Conv2D(
                chs[3], 8, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2D(
                8, momentum=bn_mom),
            nn.ReLU(),
            nn.UpsamplingBilinear2D(scale_factor=2))
        self.conv_out = nn.Conv2D(
            8, num_classes, kernel_size=3, stride=1, padding=1)
        self.conv_out_class = nn.Conv2D(
            chs[3], 1, kernel_size=1, stride=1, padding=0)

    def _encode(self, img):
        """Run the four encoder blocks and return every intermediate output."""
        e1 = self.block1(img)
        e2 = self.block2(e1)
        e3 = self.block3(e2)
        top = self.block4(e3)
        return e1, e2, e3, top

    def forward(self, t1, t2):
        # Both phases share the same encoder weights.
        e1_1, e2_1, e3_1, y1 = self._encode(t1)
        e1_2, e2_2, e3_2, y2 = self._encode(t2)

        y1 = self.center(y1)
        y2 = self.center(y2)
        change = self.df4(y1, y2)

        y1 = self.decoder3(y1, e3_1)
        y2 = self.decoder3(y2, e3_2)
        change = self.catc3(change, self.df3(y1, y2))

        y1 = self.decoder2(y1, e2_1)
        y2 = self.decoder2(y2, e2_2)
        change = self.catc2(change, self.df2(y1, y2))

        y1 = self.decoder1(y1, e1_1)
        y2 = self.decoder1(y2, e1_2)
        change = self.catc1(change, self.df1(y1, y2))

        y = self.conv_out(self.upsample_x2(change))

        if not self.training:
            return [y]

        # Auxiliary single-channel maps are only produced in training mode.
        y1 = self.conv_out_class(y1)
        y2 = self.conv_out_class(y2)
        return [y, [y1, y2]]

@ -0,0 +1,15 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .fccdn_loss import fccdn_ssl_loss

@ -0,0 +1,170 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class DiceLoss(nn.Layer):
    """
    Soft Dice loss, i.e. ``1 - dice_coefficient``.

    Args:
        batch (bool): If True, the coefficient is computed from sums over the
            whole tensors. Otherwise axis 1 is reduced three times in a row
            (so the inputs need at least four axes) and the resulting scores
            are averaged. Default: True.
    """

    def __init__(self, batch=True):
        super().__init__()
        self.batch = batch

    def soft_dice_coeff(self, y_pred, y_true):
        """Return the (mean) soft Dice coefficient of `y_pred` and `y_true`."""
        smooth = 0.00001  # keeps the ratio defined when both sums are zero
        overlap = y_true * y_pred

        if self.batch:
            true_sum = paddle.sum(y_true)
            pred_sum = paddle.sum(y_pred)
            inter_sum = paddle.sum(overlap)
        else:
            true_sum = y_true.sum(1).sum(1).sum(1)
            pred_sum = y_pred.sum(1).sum(1).sum(1)
            inter_sum = overlap.sum(1).sum(1).sum(1)

        score = (2. * inter_sum + smooth) / (true_sum + pred_sum + smooth)
        return score.mean()

    def soft_dice_loss(self, y_pred, y_true):
        """Return ``1 - soft_dice_coeff(y_pred, y_true)``."""
        return 1 - self.soft_dice_coeff(y_pred, y_true)

    def forward(self, y_pred, y_true):
        # Only the prediction is cast; `y_true` is used as given.
        pred = y_pred.astype(paddle.float32)
        return self.soft_dice_loss(pred, y_true)
class MultiClassDiceLoss(nn.Layer):
    """
    Weighted average of per-class soft Dice losses.

    Args:
        weight (Sequence[float]|paddle.Tensor): Per-class weights, indexed by
            class id.
        batch (bool): Forwarded to DiceLoss. Default: True.
        ignore_index (int): Class id excluded from the loss. Default: -1.
        do_softmax (bool): If True, softmax is applied to the predictions
            along axis 1 first. Default: False.
        **kwargs: Accepted and ignored.
    """

    def __init__(
            self,
            weight,
            batch=True,
            ignore_index=-1,
            do_softmax=False,
            **kwargs, ):
        super(MultiClassDiceLoss, self).__init__()
        self.ignore_index = ignore_index
        self.weight = weight
        self.do_softmax = do_softmax
        self.binary_diceloss = DiceLoss(batch)

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred (paddle.Tensor): Channel-first predictions; axis 1 is the
                class axis.
            y_true (paddle.Tensor): Integer class-id labels. The one-hot
                tensor is transposed with a 4-axis permutation, so the labels
                must have three axes.

        Returns:
            paddle.Tensor: Sum of weighted per-class Dice losses divided by
                the number of classes that were not ignored.
        """
        if self.do_softmax:
            y_pred = F.softmax(y_pred, axis=1)

        num_classes = y_pred.shape[1]
        # Paddle tensors use `astype`/`transpose`; `long()`/`permute()` are
        # PyTorch spellings. The one-hot class axis is moved from last to
        # axis 1 to match `y_pred`.
        y_true = F.one_hot(y_true.astype("int64"), num_classes)
        y_true = y_true.transpose([0, 3, 1, 2])

        total_loss = 0.0
        tmp_i = 0.0
        for i in range(num_classes):
            if i == self.ignore_index:
                continue
            diceloss = self.binary_diceloss(y_pred[:, i, :, :],
                                            y_true[:, i, :, :])
            # `*` works for both tensor and plain-float weights.
            total_loss += diceloss * self.weight[i]
            tmp_i += 1.0
        return total_loss / tmp_i
class DiceBCELoss(nn.Layer):
    """
    Binary change detection task loss: soft Dice loss plus binary
    cross-entropy, both computed on the same probabilities.
    """

    def __init__(self):
        super().__init__()
        self.bce_loss = nn.BCELoss()
        self.binnary_dice = DiceLoss()

    def forward(self, scores, labels, do_sigmoid=True):
        """
        Args:
            scores (paddle.Tensor): Predictions. Axis 1 is squeezed when the
                tensor has more than three axes.
            labels (paddle.Tensor): Targets, squeezed the same way.
            do_sigmoid (bool): If True, sigmoid is applied to a copy of
                `scores` first; pass False when `scores` already holds
                probabilities. Default: True.

        Returns:
            paddle.Tensor: Dice loss + BCE loss.
        """
        scores = scores.squeeze(1) if len(scores.shape) > 3 else scores
        labels = labels.squeeze(1) if len(labels.shape) > 3 else labels

        if do_sigmoid:
            scores = F.sigmoid(scores.clone())

        dice_term = self.binnary_dice(scores, labels)
        bce_term = self.bce_loss(scores, labels)
        return dice_term + bce_term
class McDiceBCELoss(nn.Layer):
    """
    Multi-class change detection task loss: weighted multi-class Dice loss
    plus weighted cross-entropy.

    Args:
        weight (paddle.Tensor): Per-class weights shared by both terms.
        do_sigmoid (bool): Whether the Dice term should normalize the scores
            (with softmax along the class axis) first. Default: True.
    """

    def __init__(self, weight, do_sigmoid=True):
        super(McDiceBCELoss, self).__init__()
        # Scores are channel-first (see `forward`), so the class axis is 1;
        # Paddle's CrossEntropyLoss defaults to the last axis.
        self.ce_loss = nn.CrossEntropyLoss(weight, axis=1)
        # Pass by keyword: the second positional parameter of
        # MultiClassDiceLoss is `batch`, not the activation switch.
        self.dice = MultiClassDiceLoss(weight, do_softmax=do_sigmoid)

    def forward(self, scores, labels):
        """
        Args:
            scores (paddle.Tensor): Class scores with the class axis at 1. A
                tensor with fewer than four axes gets a singleton axis
                inserted at position 1.
            labels (paddle.Tensor): Integer class ids, with or without a
                singleton axis at position 1.

        Returns:
            paddle.Tensor: Dice loss + cross-entropy loss.
        """
        if len(scores.shape) < 4:
            scores = scores.unsqueeze(1)
        if len(labels.shape) < 4:
            labels = labels.unsqueeze(1)
        # The Dice term one-hot encodes the labels and needs them without the
        # singleton class axis; cross-entropy takes them with that axis kept.
        diceloss = self.dice(scores, labels.squeeze(1))
        bceloss = self.ce_loss(scores, labels)
        return diceloss + bceloss
def fccdn_ssl_loss(logits_list, labels):
    """
    Self-supervised learning loss for change detection.

    The original article refers to
        Pan Chen, et al., "FCCDN: Feature Constraint Network for VHR Image Change Detection"
        (https://arxiv.org/pdf/2105.10860.pdf).

    Args:
        logits_list (list[paddle.Tensor]): Single-channel segmentation logit maps for each of the two temporal phases.
        labels (paddle.Tensor): Binary change labels.

    Returns:
        paddle.Tensor: Sum of four Dice+BCE terms, each weighted by 0.2.
    """

    def _binarize(prob):
        # Hard 0/1 map: values <= 0.5 become 0, everything else becomes 1.
        return paddle.where(prob <= 0.5,
                            paddle.zeros_like(prob), paddle.ones_like(prob))

    # Create loss
    criterion = DiceBCELoss()

    # Bring the change labels to the resolution of the logit maps
    target_hw = [logits_list[0].shape[-2], logits_list[0].shape[-1]]
    labels_ds = F.interpolate(x=labels.unsqueeze(1), size=target_hw)
    labels_type = str(labels_ds.dtype)
    assert "int" in labels_type or "bool" in labels_type,\
        f"Expected dtype of labels to be int or bool, but got {labels_type}"
    changed = labels_ds == 1
    unchanged = labels_ds != 1

    # Seg map
    prob_pre = F.sigmoid(logits_list[0]).clone()
    prob_post = F.sigmoid(logits_list[1]).clone()
    prob_pre_chg = prob_pre.clone()
    prob_post_chg = prob_post.clone()

    # Zero out changed pixels in the first pair, unchanged ones in the second.
    prob_pre = paddle.where(changed, paddle.zeros_like(prob_pre), prob_pre)
    prob_post = paddle.where(changed, paddle.zeros_like(prob_post), prob_post)
    prob_pre_chg = paddle.where(unchanged,
                                paddle.zeros_like(prob_pre_chg), prob_pre_chg)
    prob_post_chg = paddle.where(unchanged,
                                 paddle.zeros_like(prob_post_chg),
                                 prob_post_chg)

    # Pseudo labels obtained by thresholding each masked probability map
    pseudo_pre = _binarize(prob_pre)
    pseudo_post = _binarize(prob_post)
    pseudo_pre_chg = _binarize(prob_pre_chg)
    pseudo_post_chg = _binarize(prob_post_chg)

    # Seg loss
    labels_ds = labels_ds.astype(paddle.float32)
    # Unchanged pixels: each phase is supervised by the other phase's map.
    loss_aux = 0.2 * criterion(prob_pre, pseudo_post, False)
    loss_aux += 0.2 * criterion(prob_post, pseudo_pre, False)
    # Changed pixels: the target is the change label minus the other phase's
    # pseudo label.
    loss_aux += 0.2 * criterion(prob_pre_chg, labels_ds - pseudo_post_chg,
                                False)
    loss_aux += 0.2 * criterion(prob_post_chg, labels_ds - pseudo_pre_chg,
                                False)
    return loss_aux

@ -37,7 +37,7 @@ from .utils import seg_metrics as metrics
__all__ = [
"CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
"SNUNet", "DSIFN", "DSAMNet", "ChangeStar", "ChangeFormer"
"SNUNet", "DSIFN", "DSAMNet", "ChangeStar", "ChangeFormer", "FCCDN"
]
@ -1055,7 +1055,7 @@ class ChangeStar(BaseChangeDetector):
if self.use_mixed_loss is False:
return {
# XXX: make sure the shallow copy works correctly here.
'types': [seglosses.CrossEntropyLoss()] * 4,
'types': [seg_losses.CrossEntropyLoss()] * 4,
'coef': [1.0] * 4
}
else:
@ -1082,3 +1082,31 @@ class ChangeFormer(BaseChangeDetector):
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
**params)
class FCCDN(BaseChangeDetector):
    """
    Change detector built around the 'FCCDN' model.

    Args:
        in_channels (int): Number of input channels, forwarded to the base
            class through `params`. Default: 3.
        num_classes (int): Number of target classes. Default: 2.
        use_mixed_loss (bool): Must be False; `default_loss` raises
            otherwise. Default: False.
        losses: Forwarded to the base class unchanged. Default: None.
        **params: Additional keyword arguments forwarded to the base class.
    """

    def __init__(self,
                 in_channels=3,
                 num_classes=2,
                 use_mixed_loss=False,
                 losses=None,
                 **params):
        params['in_channels'] = in_channels
        super().__init__(
            model_name='FCCDN',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            losses=losses,
            **params)

    def default_loss(self):
        """
        Return the default loss configuration: cross-entropy and the FCCDN
        self-supervised loss, both with coefficient 1.0.

        Raises:
            ValueError: If `use_mixed_loss` is not False.
        """
        if self.use_mixed_loss is not False:
            raise ValueError(
                f"Currently `use_mixed_loss` must be set to False for {self.__class__}"
            )

        loss_types = [seg_losses.CrossEntropyLoss(), cmcd.losses.fccdn_ssl_loss]
        return {'types': loss_types, 'coef': [1.0, 1.0]}

@ -0,0 +1,13 @@
# Basic configurations of FCCDN
_base_: ../_base_/airchange.yaml
save_dir: ./test_tipc/output/cd/fccdn/
model: !Node
type: FCCDN
learning_rate: 0.07
lr_decay_power: 0.6
log_interval_steps: 100
save_interval_epochs: 3

@ -0,0 +1,53 @@
===========================train_params===========================
model_name:cd:fccdn
python:python
gpu_list:0
use_gpu:null|null
--precision:null
--num_epochs:lite_train_lite_infer=15|lite_train_whole_infer=15|whole_train_whole_infer=15
--save_dir:adaptive
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4
--model_path:null
train_model_name:best_model
train_infer_file_list:./test_tipc/data/airchange/:./test_tipc/data/airchange/eval.txt
null:null
##
trainer:norm
norm_train:test_tipc/run_task.py train cd --config ./test_tipc/configs/cd/fccdn/fccdn.yaml
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================export_params===========================
--save_dir:adaptive
--model_dir:adaptive
--fixed_input_shape:[1,3,256,256]
norm_export:deploy/export/export_model.py
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
===========================infer_params===========================
infer_model:null
infer_export:null
infer_quant:False
inference:test_tipc/infer.py
--device:cpu|gpu
--enable_mkldnn:True
--cpu_threads:6
--batch_size:1
--use_trt:False
--precision:fp32
--model_dir:null
--file_list:null:null
--save_log_path:null
--benchmark:True
--model_name:fccdn
null:null

@ -0,0 +1,94 @@
#!/usr/bin/env python

# Example training script for the FCCDN change detection model.
# Make sure the PaddleRS library is correctly installed before running it.

import paddlers as pdrs
from paddlers import transforms as T

# Directory where the dataset is stored
DATA_DIR = './data/airchange/'
# Path to the `file_list` of the training set
TRAIN_FILE_LIST_PATH = './data/airchange/train.txt'
# Path to the `file_list` of the validation set
EVAL_FILE_LIST_PATH = './data/airchange/eval.txt'
# Experiment directory where model weights and results are saved
EXP_DIR = './output/fccdn/'

# Download and decompress the AirChange dataset
pdrs.utils.download_and_decompress(
    'https://paddlers.bj.bcebos.com/datasets/airchange.zip', path='./data/')

# Define the data transforms (augmentation, preprocessing, etc.) used for
# training and validation.
# Compose chains several transforms; they are applied sequentially, in order.
# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
train_transforms = T.Compose([
    # Read the images
    T.DecodeImg(),
    # Random crop
    T.RandomCrop(
        # The cropped region is resized to 256x256
        crop_size=256,
        # The aspect ratio of the cropped region varies between 0.5 and 2
        aspect_ratio=[0.5, 2.0],
        # The side length of the cropped region relative to the original
        # image varies within a range, never below 1/5 of the original
        scaling=[0.2, 1.0]),
    # Random horizontal flip with a probability of 50%
    T.RandomHorizontalFlip(prob=0.5),
    # Normalize the data to [-1, 1]
    T.Normalize(
        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    T.ArrangeChangeDetector('train')
])

eval_transforms = T.Compose([
    T.DecodeImg(),
    # Validation must use the same normalization as training
    T.Normalize(
        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    T.ReloadMask(),
    T.ArrangeChangeDetector('eval')
])

# Build the datasets used for training and validation
train_dataset = pdrs.datasets.CDDataset(
    data_dir=DATA_DIR,
    file_list=TRAIN_FILE_LIST_PATH,
    label_list=None,
    transforms=train_transforms,
    num_workers=0,
    shuffle=True,
    with_seg_labels=False,
    binarize_labels=True)

eval_dataset = pdrs.datasets.CDDataset(
    data_dir=DATA_DIR,
    file_list=EVAL_FILE_LIST_PATH,
    label_list=None,
    transforms=eval_transforms,
    num_workers=0,
    shuffle=False,
    with_seg_labels=False,
    binarize_labels=True)

# Build the FCCDN model with default parameters.
# For the supported models and their input parameters, see:
# https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
model = pdrs.tasks.cd.FCCDN()

# Run model training
model.train(
    num_epochs=5,
    train_dataset=train_dataset,
    train_batch_size=4,
    eval_dataset=eval_dataset,
    save_interval_epochs=2,
    # Number of iterations between two log records
    log_interval_steps=50,
    save_dir=EXP_DIR,
    # Whether to use early stopping, i.e. end training early once the
    # accuracy stops improving
    early_stop=False,
    # Whether to enable VisualDL logging
    use_vdl=True,
    # Checkpoint to resume training from
    resume_checkpoint=None)
Loading…
Cancel
Save