diff --git a/.github/workflows/build_and_test.yaml b/.github/workflows/build.yaml
similarity index 66%
rename from .github/workflows/build_and_test.yaml
rename to .github/workflows/build.yaml
index 339634d..f9d3d3f 100644
--- a/.github/workflows/build_and_test.yaml
+++ b/.github/workflows/build.yaml
@@ -1,4 +1,4 @@
-name: build and test
+name: build
on:
push:
@@ -17,7 +17,7 @@ concurrency:
cancel-in-progress: true
jobs:
- build_and_test_cpu:
+ build_cpu:
runs-on: ${{ matrix.os }}
strategy:
matrix:
@@ -53,29 +53,5 @@ jobs:
python -m pip install -e .
- name: Install GDAL
run: python -m pip install ${{ matrix.gdal-whl-url }}
- - name: Run unittests
- run: |
- cd tests
- bash run_fast_tests.sh
- shell: bash
-
- build_and_test_cuda102:
- runs-on: ubuntu-18.04
- container:
- image: registry.baidubce.com/paddlepaddle/paddle:2.3.1-gpu-cuda10.2-cudnn7
- steps:
- - uses: actions/checkout@v3
- - name: Upgrade pip
- run: python3.7 -m pip install pip --upgrade --user
- - name: Install PaddleRS
- run: |
- python3.7 -m pip install -r requirements.txt
- python3.7 -m pip install -e .
- - name: Install GDAL
- run: python3.7 -m pip install https://versaweb.dl.sourceforge.net/project/gdal-wheels-for-linux/GDAL-3.4.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl
- # Do not run unittests, because there is NO GPU in the machine.
- # - name: Run unittests
- # run: |
- # cd tests
- # bash run_fast_tests.sh
- # shell: bash
\ No newline at end of file
+ - name: Test installation
+ run: python -c "import paddlers; print(paddlers.__version__)"
diff --git a/README.md b/README.md
index b67b4b1..7f38bec 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@
[![license](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
- [![build status](https://github.com/PaddlePaddle/PaddleRS/actions/workflows/build_and_test.yaml/badge.svg?branch=develop)](https://github.com/PaddlePaddle/PaddleRS/actions)
+ [![build status](https://github.com/PaddlePaddle/PaddleRS/actions/workflows/build.yaml/badge.svg?branch=develop)](https://github.com/PaddlePaddle/PaddleRS/actions)
![python version](https://img.shields.io/badge/python-3.7+-orange.svg)
![support os](https://img.shields.io/badge/os-linux%2C%20win%2C%20mac-yellow.svg)
@@ -179,7 +179,7 @@ Key parts of the PaddleRS directory tree are as follows:
* If you find any problems with PaddleRS or have suggestions for it, please file them via [GitHub Issues](https://github.com/PaddlePaddle/PaddleRS/issues).
* You are welcome to join the PaddleRS WeChat group
-
+
## Tutorials
diff --git a/docs/apis/train.md b/docs/apis/train.md
index b5572db..5ac48f3 100644
--- a/docs/apis/train.md
+++ b/docs/apis/train.md
@@ -29,7 +29,7 @@
### Initialize a `BaseSegmenter` subclass object
-- In general, the `input_channel`, `num_classes`, and `use_mixed_loss` parameters can be set, specifying the number of input channels, the number of output classes, and whether to use the preset mixed loss, respectively. Some models, such as `FarSeg`, do not yet support setting the `input_channel` parameter.
+- In general, the `in_channels`, `num_classes`, and `use_mixed_loss` parameters can be set, specifying the number of input channels, the number of output classes, and whether to use the preset mixed loss, respectively. Some models, such as `FarSeg`, do not yet support setting the `in_channels` parameter.
- The `use_mixed_loss` parameter will be deprecated in the future, so its use is not recommended.
- Different subclasses support model-specific input parameters. For details, please refer to the [model definitions](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/rs_models/seg) and [trainer definitions](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/segmentor.py).
diff --git a/paddlers/rs_models/cd/__init__.py b/paddlers/rs_models/cd/__init__.py
index c3d75b5..274b2e9 100644
--- a/paddlers/rs_models/cd/__init__.py
+++ b/paddlers/rs_models/cd/__init__.py
@@ -23,3 +23,5 @@ from .fc_ef import FCEarlyFusion
from .fc_siam_conc import FCSiamConc
from .fc_siam_diff import FCSiamDiff
from .changeformer import ChangeFormer
+from .fccdn import FCCDN
+from .losses import fccdn_ssl_loss
diff --git a/paddlers/rs_models/cd/fccdn.py b/paddlers/rs_models/cd/fccdn.py
new file mode 100644
index 0000000..17d1673
--- /dev/null
+++ b/paddlers/rs_models/cd/fccdn.py
@@ -0,0 +1,478 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from .layers import BasicConv, MaxPool2x2, Conv1x1, Conv3x3
+
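+# NOTE: Paddle's BatchNorm2D `momentum` is the decay kept on the running
+# statistics (the complement of the PyTorch convention), so 1 - 0.0003
+# corresponds to a PyTorch-style momentum of 3e-4.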
+bn_mom = 1 - 0.0003
+
+
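+# Non-local block: scaled dot-product self-attention over all spatial
+# positions (cf. Wang et al., "Non-local Neural Networks").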
+class NLBlock(nn.Layer):
+ def __init__(self, in_channels):
+ super(NLBlock, self).__init__()
+ self.conv_v = BasicConv(
+ in_ch=in_channels,
+ out_ch=in_channels,
+ kernel_size=3,
+ norm=nn.BatchNorm2D(
+ in_channels, momentum=0.9))
+ self.W = BasicConv(
+ in_ch=in_channels,
+ out_ch=in_channels,
+ kernel_size=3,
+ norm=nn.BatchNorm2D(
+ in_channels, momentum=0.9),
+ act=nn.ReLU())
+
+ def forward(self, x):
+ batch_size, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
+ value = self.conv_v(x)
+ value = value.reshape([batch_size, c, value.shape[2] * value.shape[3]])
+ value = value.transpose([0, 2, 1]) # B * (H*W) * value_channels
+ key = x.reshape([batch_size, c, h * w]) # B * key_channels * (H*W)
+ query = x.reshape([batch_size, c, h * w])
+ query = query.transpose([0, 2, 1])
+
+ sim_map = paddle.matmul(query, key) # B * (H*W) * (H*W)
+ sim_map = (c**-.5) * sim_map # B * (H*W) * (H*W)
+ sim_map = nn.functional.softmax(sim_map, axis=-1) # B * (H*W) * (H*W)
+
+ context = paddle.matmul(sim_map, value)
+ context = context.transpose([0, 2, 1])
+ context = context.reshape([batch_size, c, *x.shape[2:]])
+ context = self.W(context)
+
+ return context
+
+
+class NLFPN(nn.Layer):
+    """Non-local feature pyramid network."""
+
+ def __init__(self, in_dim, reduction=True):
+ super(NLFPN, self).__init__()
+ if reduction:
+ self.reduction = BasicConv(
+ in_ch=in_dim,
+ out_ch=in_dim // 4,
+ kernel_size=1,
+ norm=nn.BatchNorm2D(
+ in_dim // 4, momentum=bn_mom),
+ act=nn.ReLU())
+ self.re_reduction = BasicConv(
+ in_ch=in_dim // 4,
+ out_ch=in_dim,
+ kernel_size=1,
+ norm=nn.BatchNorm2D(
+ in_dim, momentum=bn_mom),
+ act=nn.ReLU())
+ in_dim = in_dim // 4
+ else:
+ self.reduction = None
+ self.re_reduction = None
+ self.conv_e1 = BasicConv(
+ in_dim,
+ in_dim,
+ kernel_size=3,
+ norm=nn.BatchNorm2D(
+ in_dim, momentum=bn_mom),
+ act=nn.ReLU())
+ self.conv_e2 = BasicConv(
+ in_dim,
+ in_dim * 2,
+ kernel_size=3,
+ norm=nn.BatchNorm2D(
+ in_dim * 2, momentum=bn_mom),
+ act=nn.ReLU())
+ self.conv_e3 = BasicConv(
+ in_dim * 2,
+ in_dim * 4,
+ kernel_size=3,
+ norm=nn.BatchNorm2D(
+ in_dim * 4, momentum=bn_mom),
+ act=nn.ReLU())
+ self.conv_d1 = BasicConv(
+ in_dim,
+ in_dim,
+ kernel_size=3,
+ norm=nn.BatchNorm2D(
+ in_dim, momentum=bn_mom),
+ act=nn.ReLU())
+ self.conv_d2 = BasicConv(
+ in_dim * 2,
+ in_dim,
+ kernel_size=3,
+ norm=nn.BatchNorm2D(
+ in_dim, momentum=bn_mom),
+ act=nn.ReLU())
+ self.conv_d3 = BasicConv(
+ in_dim * 4,
+ in_dim * 2,
+ kernel_size=3,
+ norm=nn.BatchNorm2D(
+ in_dim * 2, momentum=bn_mom),
+ act=nn.ReLU())
+ self.nl3 = NLBlock(in_dim * 2)
+ self.nl2 = NLBlock(in_dim)
+ self.nl1 = NLBlock(in_dim)
+
+ self.downsample_x2 = nn.MaxPool2D(stride=2, kernel_size=2)
+ self.upsample_x2 = nn.UpsamplingBilinear2D(scale_factor=2)
+
+ def forward(self, x):
+ if self.reduction is not None:
+ x = self.reduction(x)
+ e1 = self.conv_e1(x) # C,H,W
+ e2 = self.conv_e2(self.downsample_x2(e1)) # 2C,H/2,W/2
+ e3 = self.conv_e3(self.downsample_x2(e2)) # 4C,H/4,W/4
+
+ d3 = self.conv_d3(e3) # 2C,H/4,W/4
+ nl = self.nl3(d3)
+        d3 = self.upsample_x2(paddle.multiply(d3, nl))  # 2C,H/2,W/2
+ d2 = self.conv_d2(e2 + d3) # C,H/2,W/2
+ nl = self.nl2(d2)
+ d2 = self.upsample_x2(paddle.multiply(d2, nl)) # C,H,W
+ d1 = self.conv_d1(e1 + d2)
+ nl = self.nl1(d1)
+ d1 = paddle.multiply(d1, nl) # C,H,W
+ if self.re_reduction is not None:
+ d1 = self.re_reduction(d1)
+
+ return d1
+
+
+class Cat(nn.Layer):
+ def __init__(self, in_chn_high, in_chn_low, out_chn, upsample=False):
+ super(Cat, self).__init__()
+ self.do_upsample = upsample
+ self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
+ self.conv2d = BasicConv(
+ in_chn_high + in_chn_low,
+ out_chn,
+ kernel_size=1,
+ norm=nn.BatchNorm2D(
+ out_chn, momentum=bn_mom),
+ act=nn.ReLU())
+
+ def forward(self, x, y):
+ if self.do_upsample:
+ x = self.upsample(x)
+
+ x = paddle.concat((x, y), 1)
+
+ return self.conv2d(x)
+
+
+class DoubleConv(nn.Layer):
+ def __init__(self, in_chn, out_chn, stride=1, dilation=1):
+ super(DoubleConv, self).__init__()
+ self.conv = nn.Sequential(
+ nn.Conv2D(
+ in_chn,
+ out_chn,
+ kernel_size=3,
+ stride=stride,
+ dilation=dilation,
+ padding=dilation),
+ nn.BatchNorm2D(
+ out_chn, momentum=bn_mom),
+ nn.ReLU(),
+ nn.Conv2D(
+ out_chn, out_chn, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2D(
+ out_chn, momentum=bn_mom),
+ nn.ReLU())
+
+ def forward(self, x):
+ x = self.conv(x)
+ return x
+
+
+class SEModule(nn.Layer):
+ def __init__(self, channels, reduction_channels):
+ super(SEModule, self).__init__()
+ self.fc1 = nn.Conv2D(
+ channels,
+ reduction_channels,
+ kernel_size=1,
+ padding=0,
+ bias_attr=True)
+ self.ReLU = nn.ReLU()
+ self.fc2 = nn.Conv2D(
+ reduction_channels,
+ channels,
+ kernel_size=1,
+ padding=0,
+ bias_attr=True)
+
+ def forward(self, x):
+ x_se = x.reshape(
+ [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).mean(-1).reshape(
+ [x.shape[0], x.shape[1], 1, 1])
+
+ x_se = self.fc1(x_se)
+ x_se = self.ReLU(x_se)
+ x_se = self.fc2(x_se)
+ return x * F.sigmoid(x_se)
+
+
+class BasicBlock(nn.Layer):
+ expansion = 1
+
+ def __init__(self,
+ inplanes,
+ planes,
+ downsample=None,
+ use_se=False,
+ stride=1,
+ dilation=1):
+ super(BasicBlock, self).__init__()
+ first_planes = planes
+ outplanes = planes * self.expansion
+
+ self.conv1 = DoubleConv(inplanes, first_planes)
+ self.conv2 = DoubleConv(
+ first_planes, outplanes, stride=stride, dilation=dilation)
+ self.se = SEModule(outplanes, planes // 4) if use_se else None
+ self.downsample = MaxPool2x2() if downsample else None
+ self.ReLU = nn.ReLU()
+
+ def forward(self, x):
+ out = self.conv1(x)
+ residual = out
+ out = self.conv2(out)
+
+ if self.se is not None:
+ out = self.se(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(residual)
+
+ out = out + residual
+ out = self.ReLU(out)
+ return out
+
+
+class DenseCatAdd(nn.Layer):
+ def __init__(self, in_chn, out_chn):
+ super(DenseCatAdd, self).__init__()
+ self.conv1 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
+ self.conv2 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
+ self.conv3 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
+ self.conv_out = BasicConv(
+ in_chn,
+ out_chn,
+ kernel_size=1,
+ norm=nn.BatchNorm2D(
+ out_chn, momentum=bn_mom),
+ act=nn.ReLU())
+
+ def forward(self, x, y):
+ x1 = self.conv1(x)
+ x2 = self.conv2(x1)
+ x3 = self.conv3(x2 + x1)
+
+ y1 = self.conv1(y)
+ y2 = self.conv2(y1)
+ y3 = self.conv3(y2 + y1)
+
+ return self.conv_out(x1 + x2 + x3 + y1 + y2 + y3)
+
+
+class DenseCatDiff(nn.Layer):
+ def __init__(self, in_chn, out_chn):
+ super(DenseCatDiff, self).__init__()
+ self.conv1 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
+ self.conv2 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
+ self.conv3 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU())
+ self.conv_out = BasicConv(
+ in_ch=in_chn,
+ out_ch=out_chn,
+ kernel_size=1,
+ norm=nn.BatchNorm2D(
+ out_chn, momentum=bn_mom),
+ act=nn.ReLU())
+
+ def forward(self, x, y):
+ x1 = self.conv1(x)
+ x2 = self.conv2(x1)
+ x3 = self.conv3(x2 + x1)
+
+ y1 = self.conv1(y)
+ y2 = self.conv2(y1)
+ y3 = self.conv3(y2 + y1)
+ out = self.conv_out(paddle.abs(x1 + x2 + x3 - y1 - y2 - y3))
+ return out
+
+
+class DFModule(nn.Layer):
+ """Dense connection-based feature fusion module"""
+
+ def __init__(self, dim_in, dim_out, reduction=True):
+ super(DFModule, self).__init__()
+ if reduction:
+ self.reduction = Conv1x1(
+ dim_in,
+ dim_in // 2,
+ norm=nn.BatchNorm2D(
+ dim_in // 2, momentum=bn_mom),
+ act=nn.ReLU())
+ dim_in = dim_in // 2
+ else:
+ self.reduction = None
+ self.cat1 = DenseCatAdd(dim_in, dim_out)
+ self.cat2 = DenseCatDiff(dim_in, dim_out)
+ self.conv1 = Conv3x3(
+ dim_out,
+ dim_out,
+ norm=nn.BatchNorm2D(
+ dim_out, momentum=bn_mom),
+ act=nn.ReLU())
+
+ def forward(self, x1, x2):
+ if self.reduction is not None:
+ x1 = self.reduction(x1)
+ x2 = self.reduction(x2)
+ x_add = self.cat1(x1, x2)
+ x_diff = self.cat2(x1, x2)
+ y = self.conv1(x_diff) + x_add
+ return y
+
+
+class FCCDN(nn.Layer):
+ """
+ The FCCDN implementation based on PaddlePaddle.
+
+ The original article refers to
+ Pan Chen, et al., "FCCDN: Feature Constraint Network for VHR Image Change Detection"
+ (https://arxiv.org/pdf/2105.10860.pdf).
+
+ Args:
+ in_channels (int): Number of input channels. Default: 3.
+ num_classes (int): Number of target classes. Default: 2.
+        os (int): Output stride of the encoder. Default: 16.
+ use_se (bool): Whether to use SEModule. Default: True.
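+
+    Examples:
+        A minimal forward-pass sketch (shapes are illustrative):
+
+            model = FCCDN(in_channels=3, num_classes=2)
+            t1 = paddle.rand([1, 3, 256, 256])
+            t2 = paddle.rand([1, 3, 256, 256])
+            model.eval()
+            out, = model(t1, t2)  # Change logits of shape [1, 2, 256, 256]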
+ """
+
+ def __init__(self, in_channels=3, num_classes=2, os=16, use_se=True):
+ super(FCCDN, self).__init__()
+ if os >= 16:
+ dilation_list = [1, 1, 1, 1]
+ stride_list = [2, 2, 2, 2]
+ pool_list = [True, True, True, True]
+ elif os == 8:
+ dilation_list = [2, 1, 1, 1]
+ stride_list = [1, 2, 2, 2]
+ pool_list = [False, True, True, True]
+ else:
+ dilation_list = [2, 2, 1, 1]
+ stride_list = [1, 1, 2, 2]
+ pool_list = [False, False, True, True]
+ se_list = [use_se, use_se, use_se, use_se]
+ channel_list = [256, 128, 64, 32]
+ # Encoder
+ self.block1 = BasicBlock(in_channels, channel_list[3], pool_list[3],
+ se_list[3], stride_list[3], dilation_list[3])
+ self.block2 = BasicBlock(channel_list[3], channel_list[2], pool_list[2],
+ se_list[2], stride_list[2], dilation_list[2])
+ self.block3 = BasicBlock(channel_list[2], channel_list[1], pool_list[1],
+ se_list[1], stride_list[1], dilation_list[1])
+ self.block4 = BasicBlock(channel_list[1], channel_list[0], pool_list[0],
+ se_list[0], stride_list[0], dilation_list[0])
+
+ # Center
+ self.center = NLFPN(channel_list[0], True)
+
+ # Decoder
+ self.decoder3 = Cat(channel_list[0],
+ channel_list[1],
+ channel_list[1],
+ upsample=pool_list[0])
+ self.decoder2 = Cat(channel_list[1],
+ channel_list[2],
+ channel_list[2],
+ upsample=pool_list[1])
+ self.decoder1 = Cat(channel_list[2],
+ channel_list[3],
+ channel_list[3],
+ upsample=pool_list[2])
+
+ self.df1 = DFModule(channel_list[3], channel_list[3], True)
+ self.df2 = DFModule(channel_list[2], channel_list[2], True)
+ self.df3 = DFModule(channel_list[1], channel_list[1], True)
+ self.df4 = DFModule(channel_list[0], channel_list[0], True)
+
+ self.catc3 = Cat(channel_list[0],
+ channel_list[1],
+ channel_list[1],
+ upsample=pool_list[0])
+ self.catc2 = Cat(channel_list[1],
+ channel_list[2],
+ channel_list[2],
+ upsample=pool_list[1])
+ self.catc1 = Cat(channel_list[2],
+ channel_list[3],
+ channel_list[3],
+ upsample=pool_list[2])
+
+ self.upsample_x2 = nn.Sequential(
+ nn.Conv2D(
+ channel_list[3], 8, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2D(
+ 8, momentum=bn_mom),
+ nn.ReLU(),
+ nn.UpsamplingBilinear2D(scale_factor=2))
+
+ self.conv_out = nn.Conv2D(
+ 8, num_classes, kernel_size=3, stride=1, padding=1)
+ self.conv_out_class = nn.Conv2D(
+ channel_list[3], 1, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, t1, t2):
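+        # Siamese encoding: the two temporal images pass through shared
+        # (weight-tied) encoder blocks.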
+ e1_1 = self.block1(t1)
+ e2_1 = self.block2(e1_1)
+ e3_1 = self.block3(e2_1)
+ y1 = self.block4(e3_1)
+
+ e1_2 = self.block1(t2)
+ e2_2 = self.block2(e1_2)
+ e3_2 = self.block3(e2_2)
+ y2 = self.block4(e3_2)
+
+ y1 = self.center(y1)
+ y2 = self.center(y2)
+ c = self.df4(y1, y2)
+
+ y1 = self.decoder3(y1, e3_1)
+ y2 = self.decoder3(y2, e3_2)
+ c = self.catc3(c, self.df3(y1, y2))
+
+ y1 = self.decoder2(y1, e2_1)
+ y2 = self.decoder2(y2, e2_2)
+ c = self.catc2(c, self.df2(y1, y2))
+
+ y1 = self.decoder1(y1, e1_1)
+ y2 = self.decoder1(y2, e1_2)
+
+ c = self.catc1(c, self.df1(y1, y2))
+ y = self.conv_out(self.upsample_x2(c))
+
+ if self.training:
+ y1 = self.conv_out_class(y1)
+ y2 = self.conv_out_class(y2)
+ return [y, [y1, y2]]
+ else:
+ return [y]
diff --git a/paddlers/rs_models/cd/losses/__init__.py b/paddlers/rs_models/cd/losses/__init__.py
new file mode 100644
index 0000000..49465ff
--- /dev/null
+++ b/paddlers/rs_models/cd/losses/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .fccdn_loss import fccdn_ssl_loss
diff --git a/paddlers/rs_models/cd/losses/fccdn_loss.py b/paddlers/rs_models/cd/losses/fccdn_loss.py
new file mode 100644
index 0000000..49d2b4c
--- /dev/null
+++ b/paddlers/rs_models/cd/losses/fccdn_loss.py
@@ -0,0 +1,170 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class DiceLoss(nn.Layer):
+ def __init__(self, batch=True):
+ super(DiceLoss, self).__init__()
+ self.batch = batch
+
+ def soft_dice_coeff(self, y_pred, y_true):
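+        # Soft Dice coefficient: (2 * |X∩Y| + eps) / (|X| + |Y| + eps),
+        # accumulated over the whole batch when `batch` is True, else per sample.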
+ smooth = 0.00001
+ if self.batch:
+ i = paddle.sum(y_true)
+ j = paddle.sum(y_pred)
+ intersection = paddle.sum(y_true * y_pred)
+ else:
+ i = y_true.sum(1).sum(1).sum(1)
+ j = y_pred.sum(1).sum(1).sum(1)
+ intersection = (y_true * y_pred).sum(1).sum(1).sum(1)
+ score = (2. * intersection + smooth) / (i + j + smooth)
+ return score.mean()
+
+ def soft_dice_loss(self, y_pred, y_true):
+ loss = 1 - self.soft_dice_coeff(y_pred, y_true)
+ return loss
+
+ def forward(self, y_pred, y_true):
+ return self.soft_dice_loss(y_pred.astype(paddle.float32), y_true)
+
+
+class MultiClassDiceLoss(nn.Layer):
+ def __init__(
+ self,
+ weight,
+ batch=True,
+ ignore_index=-1,
+ do_softmax=False,
+ **kwargs, ):
+ super(MultiClassDiceLoss, self).__init__()
+ self.ignore_index = ignore_index
+ self.weight = weight
+ self.do_softmax = do_softmax
+ self.binary_diceloss = DiceLoss(batch)
+
+ def forward(self, y_pred, y_true):
+ if self.do_softmax:
+ y_pred = paddle.nn.functional.softmax(y_pred, axis=1)
+        y_true = F.one_hot(y_true.astype('int64'),
+                           y_pred.shape[1]).transpose([0, 3, 1, 2])
+ total_loss = 0.0
+ tmp_i = 0.0
+ for i in range(y_pred.shape[1]):
+ if i != self.ignore_index:
+ diceloss = self.binary_diceloss(y_pred[:, i, :, :],
+ y_true[:, i, :, :])
+ total_loss += paddle.multiply(diceloss, self.weight[i])
+ tmp_i += 1.0
+ return total_loss / tmp_i
+
+
+class DiceBCELoss(nn.Layer):
+ """Binary change detection task loss"""
+
+ def __init__(self):
+ super(DiceBCELoss, self).__init__()
+ self.bce_loss = nn.BCELoss()
+        self.binary_dice = DiceLoss()
+
+ def forward(self, scores, labels, do_sigmoid=True):
+ if len(scores.shape) > 3:
+ scores = scores.squeeze(1)
+ if len(labels.shape) > 3:
+ labels = labels.squeeze(1)
+ if do_sigmoid:
+ scores = paddle.nn.functional.sigmoid(scores.clone())
+        diceloss = self.binary_dice(scores, labels)
+ bceloss = self.bce_loss(scores, labels)
+ return diceloss + bceloss
+
+
+class McDiceBCELoss(nn.Layer):
+ """Multi-class change detection task loss"""
+
+ def __init__(self, weight, do_sigmoid=True):
+ super(McDiceBCELoss, self).__init__()
+ self.ce_loss = nn.CrossEntropyLoss(weight)
+        self.dice = MultiClassDiceLoss(weight, do_softmax=do_sigmoid)
+
+ def forward(self, scores, labels):
+ if len(scores.shape) < 4:
+ scores = scores.unsqueeze(1)
+ if len(labels.shape) < 4:
+ labels = labels.unsqueeze(1)
+ diceloss = self.dice(scores, labels)
+ bceloss = self.ce_loss(scores, labels)
+ return diceloss + bceloss
+
+
+def fccdn_ssl_loss(logits_list, labels):
+ """
+ Self-supervised learning loss for change detection.
+
+ The original article refers to
+ Pan Chen, et al., "FCCDN: Feature Constraint Network for VHR Image Change Detection"
+ (https://arxiv.org/pdf/2105.10860.pdf).
+
+ Args:
+ logits_list (list[paddle.Tensor]): Single-channel segmentation logit maps for each of the two temporal phases.
+ labels (paddle.Tensor): Binary change labels.
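+
+    Returns:
+        paddle.Tensor: Scalar self-supervised constraint loss.
+
+    Example:
+        A minimal usage sketch (shapes are illustrative):
+
+            # logits1, logits2: [N, 1, H, W]; labels: int tensor, values in {0, 1}
+            loss = fccdn_ssl_loss([logits1, logits2], labels)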
+ """
+
+ # Create loss
+ criterion_ssl = DiceBCELoss()
+
+ # Get downsampled change map
+ h, w = logits_list[0].shape[-2], logits_list[0].shape[-1]
+ labels_downsample = F.interpolate(x=labels.unsqueeze(1), size=[h, w])
+ labels_type = str(labels_downsample.dtype)
+ assert "int" in labels_type or "bool" in labels_type,\
+ f"Expected dtype of labels to be int or bool, but got {labels_type}"
+
+ # Seg map
+ out1 = paddle.nn.functional.sigmoid(logits_list[0]).clone()
+ out2 = paddle.nn.functional.sigmoid(logits_list[1]).clone()
+ out3 = out1.clone()
+ out4 = out2.clone()
+
+ out1 = paddle.where(labels_downsample == 1, paddle.zeros_like(out1), out1)
+ out2 = paddle.where(labels_downsample == 1, paddle.zeros_like(out2), out2)
+ out3 = paddle.where(labels_downsample != 1, paddle.zeros_like(out3), out3)
+ out4 = paddle.where(labels_downsample != 1, paddle.zeros_like(out4), out4)
+
+ pred_seg_pre_tmp1 = paddle.where(out1 <= 0.5,
+ paddle.zeros_like(out1),
+ paddle.ones_like(out1))
+ pred_seg_post_tmp1 = paddle.where(out2 <= 0.5,
+ paddle.zeros_like(out2),
+ paddle.ones_like(out2))
+
+ pred_seg_pre_tmp2 = paddle.where(out3 <= 0.5,
+ paddle.zeros_like(out3),
+ paddle.ones_like(out3))
+ pred_seg_post_tmp2 = paddle.where(out4 <= 0.5,
+ paddle.zeros_like(out4),
+ paddle.ones_like(out4))
+
+ # Seg loss
+ labels_downsample = labels_downsample.astype(paddle.float32)
+ loss_aux = 0.2 * criterion_ssl(out1, pred_seg_post_tmp1, False)
+ loss_aux += 0.2 * criterion_ssl(out2, pred_seg_pre_tmp1, False)
+ loss_aux += 0.2 * criterion_ssl(
+ out3, labels_downsample - pred_seg_post_tmp2, False)
+ loss_aux += 0.2 * criterion_ssl(out4, labels_downsample - pred_seg_pre_tmp2,
+ False)
+
+ return loss_aux
diff --git a/paddlers/tasks/change_detector.py b/paddlers/tasks/change_detector.py
index 34ed61d..8c2f477 100644
--- a/paddlers/tasks/change_detector.py
+++ b/paddlers/tasks/change_detector.py
@@ -38,7 +38,7 @@ from .utils.infer_nets import InferCDNet
__all__ = [
"CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
- "SNUNet", "DSIFN", "DSAMNet", "ChangeStar", "ChangeFormer"
+ "SNUNet", "DSIFN", "DSAMNet", "ChangeStar", "ChangeFormer", "FCCDN"
]
@@ -1062,7 +1062,7 @@ class ChangeStar(BaseChangeDetector):
if self.use_mixed_loss is False:
return {
# XXX: make sure the shallow copy works correctly here.
- 'types': [seglosses.CrossEntropyLoss()] * 4,
+ 'types': [seg_losses.CrossEntropyLoss()] * 4,
'coef': [1.0] * 4
}
else:
@@ -1089,3 +1089,31 @@ class ChangeFormer(BaseChangeDetector):
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
**params)
+
+
+class FCCDN(BaseChangeDetector):
+ def __init__(self,
+ in_channels=3,
+ num_classes=2,
+ use_mixed_loss=False,
+ losses=None,
+ **params):
+ params.update({'in_channels': in_channels})
+ super(FCCDN, self).__init__(
+ model_name='FCCDN',
+ num_classes=num_classes,
+ use_mixed_loss=use_mixed_loss,
+ losses=losses,
+ **params)
+
+ def default_loss(self):
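+        # Combine the standard cross-entropy loss on the change map with the
+        # self-supervised feature-constraint loss on the two temporal branches.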
+ if self.use_mixed_loss is False:
+ return {
+ 'types':
+ [seg_losses.CrossEntropyLoss(), cmcd.losses.fccdn_ssl_loss],
+ 'coef': [1.0, 1.0]
+ }
+ else:
+ raise ValueError(
+ f"Currently `use_mixed_loss` must be set to False for {self.__class__}"
+ )
diff --git a/paddlers/tasks/segmenter.py b/paddlers/tasks/segmenter.py
index db6555a..b9c586f 100644
--- a/paddlers/tasks/segmenter.py
+++ b/paddlers/tasks/segmenter.py
@@ -827,7 +827,7 @@ class DeepLabV3P(BaseSegmenter):
if params.get('with_net', True):
with DisablePrint():
backbone = getattr(ppseg.models, backbone)(
- input_channel=input_channel, output_stride=output_stride)
+ input_channel=in_channels, output_stride=output_stride)
else:
backbone = None
params.update({
diff --git a/paddlers/transforms/operators.py b/paddlers/transforms/operators.py
index e1fef8a..7e2eac7 100644
--- a/paddlers/transforms/operators.py
+++ b/paddlers/transforms/operators.py
@@ -914,9 +914,9 @@ class Normalize(Transform):
std (list[float] | tuple[float], optional): Standard deviation of input
image(s). Defaults to [0.229, 0.224, 0.225].
min_val (list[float] | tuple[float], optional): Minimum value of input
- image(s). Defaults to [0, 0, 0, ].
- max_val (list[float] | tuple[float], optional): Max value of input image(s).
- Defaults to [255., 255., 255.].
+ image(s). If None, use 0 for all channels. Defaults to None.
+ max_val (list[float] | tuple[float], optional): Maximum value of input
+ image(s). If None, use 255. for all channels. Defaults to None.
apply_to_tar (bool, optional): Whether to apply transformation to the target
image. Defaults to True.
"""
diff --git a/test_tipc/configs/cd/fccdn/fccdn.yaml b/test_tipc/configs/cd/fccdn/fccdn.yaml
new file mode 100644
index 0000000..8b93717
--- /dev/null
+++ b/test_tipc/configs/cd/fccdn/fccdn.yaml
@@ -0,0 +1,13 @@
+# Basic configurations of FCCDN
+
+_base_: ../_base_/airchange.yaml
+
+save_dir: ./test_tipc/output/cd/fccdn/
+
+model: !Node
+ type: FCCDN
+
+learning_rate: 0.07
+lr_decay_power: 0.6
+log_interval_steps: 100
+save_interval_epochs: 3
diff --git a/test_tipc/configs/cd/fccdn/train_infer_python.txt b/test_tipc/configs/cd/fccdn/train_infer_python.txt
new file mode 100644
index 0000000..26147e5
--- /dev/null
+++ b/test_tipc/configs/cd/fccdn/train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:cd:fccdn
+python:python
+gpu_list:0
+use_gpu:null|null
+--precision:null
+--num_epochs:lite_train_lite_infer=15|lite_train_whole_infer=15|whole_train_whole_infer=15
+--save_dir:adaptive
+--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4
+--model_path:null
+train_model_name:best_model
+train_infer_file_list:./test_tipc/data/airchange/:./test_tipc/data/airchange/eval.txt
+null:null
+##
+trainer:norm
+norm_train:test_tipc/run_task.py train cd --config ./test_tipc/configs/cd/fccdn/fccdn.yaml
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================export_params===========================
+--save_dir:adaptive
+--model_dir:adaptive
+--fixed_input_shape:[1,3,256,256]
+norm_export:deploy/export/export_model.py
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+===========================infer_params===========================
+infer_model:null
+infer_export:null
+infer_quant:False
+inference:test_tipc/infer.py
+--device:cpu|gpu
+--enable_mkldnn:True
+--cpu_threads:6
+--batch_size:1
+--use_trt:False
+--precision:fp32
+--model_dir:null
+--file_list:null:null
+--save_log_path:null
+--benchmark:True
+--model_name:fccdn
+null:null
diff --git a/test_tipc/configs/seg/unet/unet.yaml b/test_tipc/configs/seg/unet/unet.yaml
index 8d2af9c..4077bf0 100644
--- a/test_tipc/configs/seg/unet/unet.yaml
+++ b/test_tipc/configs/seg/unet/unet.yaml
@@ -7,5 +7,5 @@ save_dir: ./test_tipc/output/seg/unet/
model: !Node
type: UNet
args:
- input_channel: 10
+ in_channels: 10
num_classes: 5
\ No newline at end of file
diff --git a/tests/rs_models/test_cd_models.py b/tests/rs_models/test_cd_models.py
index f1d1e6e..3b3e683 100644
--- a/tests/rs_models/test_cd_models.py
+++ b/tests/rs_models/test_cd_models.py
@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import platform
from itertools import cycle
import paddlers
-from rs_models.test_model import TestModel
+from rs_models.test_model import TestModel, allow_oom
__all__ = [
'TestBITModel', 'TestCDNetModel', 'TestChangeStarModel', 'TestDSAMNetModel',
@@ -200,6 +199,7 @@ class TestSNUNetModel(TestCDModel):
] # yapf: disable
+@allow_oom
class TestSTANetModel(TestCDModel):
MODEL_CLASS = paddlers.rs_models.cd.STANet
@@ -214,6 +214,7 @@ class TestSTANetModel(TestCDModel):
] # yapf: disable
+@allow_oom
class TestChangeFormerModel(TestCDModel):
MODEL_CLASS = paddlers.rs_models.cd.ChangeFormer
@@ -224,9 +225,3 @@ class TestChangeFormerModel(TestCDModel):
dict(**base_spec, decoder_softmax=True),
dict(**base_spec, embed_dim=56)
] # yapf: disable
-
-
-# HACK:FIXME: We observe an OOM error when running TestSTANetModel.test_forward() on a Windows machine.
-# Currently, we do not perform this test.
-if platform.system() == 'Windows':
- TestSTANetModel.test_forward = lambda self: None
diff --git a/tests/rs_models/test_model.py b/tests/rs_models/test_model.py
index 06c4777..cd1d623 100644
--- a/tests/rs_models/test_model.py
+++ b/tests/rs_models/test_model.py
@@ -18,6 +18,7 @@ import paddle
import numpy as np
from paddle.static import InputSpec
+import inspect  # Needed by the allow_oom decorator added below
+from paddlers.utils import logging
from testing_utils import CommonTest
@@ -37,20 +38,26 @@ class _TestModelNamespace:
for i, (
input, model, target
) in enumerate(zip(self.inputs, self.models, self.targets)):
- with self.subTest(i=i):
+ try:
if isinstance(input, list):
output = model(*input)
else:
output = model(input)
self.check_output(output, target)
+ except:
+                logging.warning(f"Model built with spec {i} failed!")
+ raise
def test_to_static(self):
for i, (
input, model, target
) in enumerate(zip(self.inputs, self.models, self.targets)):
- with self.subTest(i=i):
+ try:
static_model = paddle.jit.to_static(
model, input_spec=self.get_input_spec(model, input))
+ except:
+                logging.warning(f"Exporting model built with spec {i} failed!")
+ raise
def check_output(self, output, target):
pass
@@ -117,4 +124,29 @@ class _TestModelNamespace:
return input_spec
+def allow_oom(cls):
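+    # Class decorator: wraps every `test*` method so that recognizable OOM
+    # failures are logged and ignored instead of failing the CI run.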
+ def _deco(func):
+ def _wrapper(self, *args, **kwargs):
+ try:
+ func(self, *args, **kwargs)
+ except (SystemError, RuntimeError, OSError, MemoryError) as e:
+ # XXX: This may not cover all OOM cases.
+ msg = str(e)
+ if "Out of memory error" in msg \
+ or "(External) CUDNN error(4), CUDNN_STATUS_INTERNAL_ERROR." in msg \
+ or isinstance(e, MemoryError):
+ logging.warning("An OOM error has been ignored.")
+ else:
+ raise
+
+ return _wrapper
+
+ for key, value in inspect.getmembers(cls):
+ if key.startswith('test'):
+ value = _deco(value)
+ setattr(cls, key, value)
+
+ return cls
+
+
TestModel = _TestModelNamespace.TestModel
diff --git a/tests/run_ci_dev.sh b/tests/run_ci_dev.sh
new file mode 100644
index 0000000..d789511
--- /dev/null
+++ b/tests/run_ci_dev.sh
@@ -0,0 +1,41 @@
+#!/bin/bash
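+
+# CI script: repoints the container's Python 2.7 symlinks to Python 3.7,
+# installs PaddlePaddle (dev or release, depending on ${branch}), PaddleRS,
+# GDAL, and AutoLog, and then runs the fast test suite.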
+
+rm -rf /usr/local/python2.7.15/bin/python
+rm -rf /usr/local/python2.7.15/bin/pip
+ln -s /usr/local/bin/python3.7 /usr/local/python2.7.15/bin/python
+ln -s /usr/local/bin/pip3.7 /usr/local/python2.7.15/bin/pip
+export PYTHONPATH=`pwd`
+
+python -m pip install --upgrade pip --ignore-installed
+# python -m pip install --upgrade numpy --ignore-installed
+python -m pip uninstall paddlepaddle-gpu -y
+if [[ ${branch} == 'develop' ]];then
+echo "checkout develop !"
+python -m pip install ${paddle_dev} --no-cache-dir
+else
+echo "checkout release !"
+python -m pip install ${paddle_release} --no-cache-dir
+fi
+
+echo -e '*****************paddle_version*****'
+python -c 'import paddle;print(paddle.version.commit)'
+echo -e '*****************paddlers_version****'
+git rev-parse HEAD
+
+pip install -r requirements.txt --ignore-installed
+pip install -e .
+pip install https://versaweb.dl.sourceforge.net/project/gdal-wheels-for-linux/GDAL-3.4.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl
+
+git clone https://github.com/LDOUBLEV/AutoLog
+cd AutoLog
+pip install -r requirements.txt
+python setup.py bdist_wheel
+pip install ./dist/auto_log*.whl
+cd ..
+
+unset http_proxy https_proxy
+
+set -e
+
+cd tests/
+bash run_fast_tests.sh
diff --git a/tests/run_tipc_lite.sh b/tests/run_tipc_lite.sh
new file mode 100644
index 0000000..50a8f9f
--- /dev/null
+++ b/tests/run_tipc_lite.sh
@@ -0,0 +1,13 @@
+#!/usr/bin/env bash
+
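+# Run the lite_train_lite_infer TIPC pipeline for every model config and
+# exit with an error as soon as a results log contains a 'failed' entry.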
+cd ..
+
+for config in $(ls test_tipc/configs/*/*/train_infer_python.txt); do
+ bash test_tipc/prepare.sh ${config} lite_train_lite_infer
+ bash test_tipc/test_train_inference_python.sh ${config} lite_train_lite_infer
+ task="$(basename $(dirname $(dirname ${config})))"
+ model="$(basename $(dirname ${config}))"
+ if grep -q 'failed' "test_tipc/output/${task}/${model}/lite_train_lite_infer/results_python.log"; then
+ exit 1
+ fi
+done
\ No newline at end of file
diff --git a/tutorials/train/change_detection/fccdn.py b/tutorials/train/change_detection/fccdn.py
new file mode 100644
index 0000000..62abbba
--- /dev/null
+++ b/tutorials/train/change_detection/fccdn.py
@@ -0,0 +1,94 @@
+#!/usr/bin/env python
+
+# Example script for training the change detection model FCCDN
+# Before running this script, please make sure PaddleRS is correctly installed
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# Directory where the dataset is stored
+DATA_DIR = './data/airchange/'
+# Path to the `file_list` file of the training set
+TRAIN_FILE_LIST_PATH = './data/airchange/train.txt'
+# Path to the `file_list` file of the validation set
+EVAL_FILE_LIST_PATH = './data/airchange/eval.txt'
+# Experiment directory, where output model weights and results are saved
+EXP_DIR = './output/fccdn/'
+
+# Download and unzip the AirChange dataset
+pdrs.utils.download_and_decompress(
+ 'https://paddlers.bj.bcebos.com/datasets/airchange.zip', path='./data/')
+
+# Define the data transforms used for training and validation (data augmentation, preprocessing, etc.)
+# Use Compose to combine multiple transforms; the transforms contained in Compose are executed sequentially
+# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
+train_transforms = T.Compose([
+    # Decode images
+    T.DecodeImg(),
+    # Random cropping
+    T.RandomCrop(
+        # The cropped region will be rescaled to 256x256
+        crop_size=256,
+        # The aspect ratio of the cropped region varies between 0.5 and 2
+        aspect_ratio=[0.5, 2.0],
+        # The crop size relative to the original image varies within a range, no smaller than 1/5 of the original dimensions
+        scaling=[0.2, 1.0]),
+    # Apply random horizontal flipping with a probability of 50%
+    T.RandomHorizontalFlip(prob=0.5),
+    # Normalize data to [-1, 1]
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeChangeDetector('train')
+])
+
+eval_transforms = T.Compose([
+ T.DecodeImg(),
+    # The validation phase must use the same data normalization as the training phase
+ T.Normalize(
+ mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ T.ReloadMask(),
+ T.ArrangeChangeDetector('eval')
+])
+
+# Build the datasets used for training and validation, respectively
+train_dataset = pdrs.datasets.CDDataset(
+ data_dir=DATA_DIR,
+ file_list=TRAIN_FILE_LIST_PATH,
+ label_list=None,
+ transforms=train_transforms,
+ num_workers=0,
+ shuffle=True,
+ with_seg_labels=False,
+ binarize_labels=True)
+
+eval_dataset = pdrs.datasets.CDDataset(
+ data_dir=DATA_DIR,
+ file_list=EVAL_FILE_LIST_PATH,
+ label_list=None,
+ transforms=eval_transforms,
+ num_workers=0,
+ shuffle=False,
+ with_seg_labels=False,
+ binarize_labels=True)
+
+# Build the FCCDN model with default parameters
+# For currently supported models and their input parameters, please refer to:
+# https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
+model = pdrs.tasks.cd.FCCDN()
+
+# Train the model
+model.train(
+ num_epochs=5,
+ train_dataset=train_dataset,
+ train_batch_size=4,
+ eval_dataset=eval_dataset,
+ save_interval_epochs=2,
+    # Interval (in iterations) between two logging events
+ log_interval_steps=50,
+ save_dir=EXP_DIR,
+    # Whether to use the early stopping strategy, which terminates training early when accuracy no longer improves
+ early_stop=False,
+    # Whether to enable VisualDL logging
+ use_vdl=True,
+    # Specify a checkpoint from which to resume training
+ resume_checkpoint=None)
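+
+# After training, the best weights can be loaded back for inference. A minimal
+# sketch (the image file names here are assumptions):
+#
+#   model = pdrs.tasks.load_model(EXP_DIR + 'best_model')
+#   pred = model.predict(('t1.png', 't2.png'), transforms=eval_transforms)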