You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
144 lines
5.2 KiB
144 lines
5.2 KiB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
import paddle |
|
import paddle.nn as nn |
|
import paddle.nn.functional as F |
|
|
|
from paddlers.datasets.cd_dataset import MaskType |
|
from paddlers.custom_models.seg import FarSeg |
|
from .layers import Conv3x3, Identity |
|
|
|
|
|
class _ChangeStarBase(nn.Layer):
    """
    Common skeleton shared by the ChangeStar models in this file.

    Three sub-layers are registered:

    * ``extract``: the supplied ``seg_model``, run on each input separately.
    * ``detect``: a ``ChangeMixin`` that receives both extracted feature maps.
    * ``segment``: a ``Conv3x3`` with 2 output channels followed by bilinear
      upsampling; it is only invoked while the layer is in training mode.

    Args:
        seg_model (nn.Layer): Feature extractor. It is called once per input
            and only the first element of what it returns is used.
        num_classes (int): Forwarded to ``ChangeMixin`` as ``out_ch``.
        mid_channels (int): Channel count used for the ``segment`` head input;
            ``ChangeMixin`` is built with twice this value as ``in_ch``.
        inner_channels (int): Forwarded to ``ChangeMixin`` as ``mid_ch``.
        num_convs (int): Forwarded to ``ChangeMixin`` as ``num_convs``.
        scale_factor (float): Scale factor of both bilinear upsampling layers.
    """

    # NOTE(review): both class attributes are presumably read by the
    # surrounding training framework; nothing in this file consumes them.
    USE_MULTITASK_DECODER = True
    # Four entries, matching the length of the training-time output list of
    # ``forward`` (assumed to describe those outputs in order -- TODO confirm).
    OUT_TYPES = (MaskType.CD, MaskType.CD, MaskType.SEG_T1, MaskType.SEG_T2)

    def __init__(self, seg_model, num_classes, mid_channels, inner_channels,
                 num_convs, scale_factor):
        super().__init__()

        self.extract = seg_model

        # The change head consumes two feature maps concatenated together,
        # hence the doubled input width.
        pair_channels = mid_channels * 2
        self.detect = ChangeMixin(
            in_ch=pair_channels,
            out_ch=num_classes,
            mid_ch=inner_channels,
            num_convs=num_convs,
            scale_factor=scale_factor)

        # Auxiliary per-image head with a fixed 2-channel output.
        seg_conv = Conv3x3(mid_channels, 2)
        seg_upsample = nn.UpsamplingBilinear2D(scale_factor=scale_factor)
        self.segment = nn.Sequential(seg_conv, seg_upsample)

        self.init_weight()

    def forward(self, t1, t2):
        """
        Run the model on a pair of inputs.

        Args:
            t1: First input, passed to ``extract``.
            t2: Second input, passed to ``extract``.

        Returns:
            list: In evaluation mode, a single-element list holding the first
                output of ``detect``. In training mode, a four-element list:
                both outputs of ``detect`` followed by the ``segment`` output
                for ``t1`` and for ``t2``.
        """
        feat_t1 = self.extract(t1)[0]
        feat_t2 = self.extract(t2)[0]
        change_12, change_21 = self.detect(feat_t1, feat_t2)

        if self.training:
            # List items are evaluated left to right, so ``segment`` sees the
            # first feature map before the second.
            return [
                change_12,
                change_21,
                self.segment(feat_t1),
                self.segment(feat_t2),
            ]

        return [change_12]

    def init_weight(self):
        """Hook called at the end of ``__init__``; intentionally a no-op here."""
        pass
|
|
|
|
|
class ChangeMixin(nn.Layer):
    """
    Change head applied to two feature maps in both concatenation orders.

    The head is an ``nn.Sequential`` stored as ``detect``: one
    ``Conv3x3(in_ch, mid_ch)`` with ``norm=True, act=True``, then
    ``num_convs - 1`` further ``Conv3x3(mid_ch, mid_ch)`` layers with the same
    flags, then a plain ``Conv3x3(mid_ch, out_ch)`` and bilinear upsampling.

    Args:
        in_ch (int): First argument of the first ``Conv3x3``.
        out_ch (int): Second argument of the last ``Conv3x3``.
        mid_ch (int): Width used by the intermediate ``Conv3x3`` layers.
        num_convs (int): Number of ``Conv3x3`` layers built with
            ``norm=True, act=True``. The first one is always created.
        scale_factor (float): Scale factor of the final bilinear upsampling.
    """

    def __init__(self, in_ch, out_ch, mid_ch, num_convs, scale_factor):
        super().__init__()

        # The first layer is unconditional; the loop adds the remaining
        # ``num_convs - 1`` layers (none when ``num_convs <= 1``).
        stack = [Conv3x3(in_ch, mid_ch, norm=True, act=True)]
        for _ in range(num_convs - 1):
            stack.append(Conv3x3(mid_ch, mid_ch, norm=True, act=True))

        stack.append(Conv3x3(mid_ch, out_ch))
        stack.append(nn.UpsamplingBilinear2D(scale_factor=scale_factor))

        self.detect = nn.Sequential(*stack)

    def forward(self, x1, x2):
        """
        Apply the shared head to ``(x1, x2)`` and to ``(x2, x1)``.

        Args:
            x1: First feature map.
            x2: Second feature map.

        Returns:
            tuple: ``(pred12, pred21)``, the head output for ``x1`` and ``x2``
                concatenated along axis 1, and for the swapped order.
        """
        # Tuple items are evaluated left to right: the (x1, x2) pass runs
        # through ``detect`` before the (x2, x1) pass.
        return (
            self.detect(paddle.concat([x1, x2], axis=1)),
            self.detect(paddle.concat([x2, x1], axis=1)),
        )
|
|
|
|
|
class ChangeStar_FarSeg(_ChangeStarBase):
    """
    ChangeStar with a FarSeg encoder, implemented with PaddlePaddle.

    Reference:
        Z. Zheng, et al., "Change is Everywhere: Single-Temporal Supervised
        Object Change Detection in Remote Sensing Imagery"
        (https://arxiv.org/abs/2108.07002).

    This implementation departs from the original code in two ways:
        1. The encoder of the FarSeg model is ResNet50.
        2. conv-bn-relu is used instead of conv-relu-bn.

    Args:
        num_classes (int): The number of target classes.
        mid_channels (int, optional): The number of channels required by the
            ChangeMixin module. Default: 256.
        inner_channels (int, optional): The number of filters used in the
            convolutional layers in the ChangeMixin module. Default: 16.
        num_convs (int, optional): The number of convolutional layers used in
            the ChangeMixin module. Default: 4.
        scale_factor (float, optional): The scaling factor of the output
            upsampling layer. Default: 4.0.
    """

    def __init__(
            self,
            num_classes,
            mid_channels=256,
            inner_channels=16,
            num_convs=4,
            scale_factor=4.0, ):
        # TODO: Configurable FarSeg model
        class _FarSegWrapper(nn.Layer):
            """Runs a FarSeg model up to its decoder and returns that output."""

            def __init__(self, seg_model):
                super().__init__()
                self._seg_model = seg_model
                # The wrapper's forward never calls the prediction conv, so it
                # is swapped for ``Identity`` (presumably a parameter-free
                # pass-through -- confirm in ``.layers``).
                self._seg_model.cls_pred_conv = Identity()

            def forward(self, x):
                model = self._seg_model
                backbone_feats = model.en(x)
                pyramid = model.fpn(backbone_feats)

                if model.scene_relation:
                    # Refine the FPN features using the ``gap`` output of the
                    # last backbone feature map.
                    gap_out = model.gap(backbone_feats[-1])
                    pyramid = model.sr(gap_out, pyramid)

                # Wrapped in a list because the base class indexes ``[0]``.
                return [model.decoder(pyramid)]

        farseg = FarSeg(out_ch=mid_channels)

        super().__init__(
            seg_model=_FarSegWrapper(farseg),
            num_classes=num_classes,
            mid_channels=mid_channels,
            inner_channels=inner_channels,
            num_convs=num_convs,
            scale_factor=scale_factor)
|
|
|
|
|
# NOTE: Currently, ChangeStar = FarSeg + ChangeMixin + SegHead |
|
# Alias: the generic name ``ChangeStar`` is bound to the FarSeg-based class, so
# both names refer to the same class object.
ChangeStar = ChangeStar_FarSeg
|
|
|