[Re-request][Feature] Add ChangeStar model (#23)
parent
29c904bc5b
commit
099d5ac8c5
3 changed files with 194 additions and 1 deletions
@ -0,0 +1,158 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
import paddle.nn.functional as F |
||||
|
||||
from paddlers.datasets.cd_dataset import MaskType |
||||
from paddlers.custom_models.seg.models import FarSeg |
||||
from .layers import Conv3x3, Identity |
||||
|
||||
|
||||
class _ChangeStarBase(nn.Layer): |
||||
|
||||
USE_MULTITASK_DECODER = True |
||||
OUT_TYPES = ( |
||||
MaskType.CD, |
||||
MaskType.CD, |
||||
MaskType.SEG_T1, |
||||
MaskType.SEG_T2 |
||||
) |
||||
|
||||
def __init__( |
||||
self, |
||||
seg_model, |
||||
num_classes, |
||||
mid_channels, |
||||
inner_channels, |
||||
num_convs, |
||||
scale_factor |
||||
): |
||||
super().__init__() |
||||
|
||||
self.extract = seg_model |
||||
self.detect = ChangeMixin( |
||||
in_ch=mid_channels*2, |
||||
out_ch=num_classes, |
||||
mid_ch=inner_channels, |
||||
num_convs=num_convs, |
||||
scale_factor=scale_factor |
||||
) |
||||
self.segment = nn.Sequential( |
||||
Conv3x3(mid_channels, 2), |
||||
nn.UpsamplingBilinear2D(scale_factor=scale_factor) |
||||
) |
||||
|
||||
self.init_weight() |
||||
|
||||
def forward(self, t1, t2): |
||||
x1 = self.extract(t1)[0] |
||||
x2 = self.extract(t2)[0] |
||||
logit12, logit21 = self.detect(x1, x2) |
||||
|
||||
if not self.training: |
||||
logit_list = [logit12] |
||||
else: |
||||
logit1 = self.segment(x1) |
||||
logit2 = self.segment(x2) |
||||
logit_list = [logit12, logit21, logit1, logit2] |
||||
|
||||
return logit_list |
||||
|
||||
def init_weight(self): |
||||
pass |
||||
|
||||
|
||||
class ChangeMixin(nn.Layer): |
||||
def __init__(self, in_ch, out_ch, mid_ch, num_convs, scale_factor): |
||||
super().__init__() |
||||
convs = [Conv3x3(in_ch, mid_ch, norm=True, act=True)] |
||||
convs += [ |
||||
Conv3x3(mid_ch, mid_ch, norm=True, act=True) |
||||
for _ in range(num_convs-1) |
||||
] |
||||
self.detect = nn.Sequential( |
||||
*convs, |
||||
Conv3x3(mid_ch, out_ch), |
||||
nn.UpsamplingBilinear2D(scale_factor=scale_factor) |
||||
) |
||||
|
||||
def forward(self, x1, x2): |
||||
pred12 = self.detect(paddle.concat([x1, x2], axis=1)) |
||||
pred21 = self.detect(paddle.concat([x2, x1], axis=1)) |
||||
return pred12, pred21 |
||||
|
||||
|
||||
class ChangeStar_FarSeg(_ChangeStarBase): |
||||
""" |
||||
The ChangeStar implementation with a FarSeg encoder based on PaddlePaddle. |
||||
|
||||
The original article refers to |
||||
Z. Zheng, et al., "Change is Everywhere: Single-Temporal Supervised Object Change Detection in Remote Sensing Imagery" |
||||
(https://arxiv.org/abs/2108.07002). |
||||
|
||||
Note that this implementation differs from the original code in two aspects: |
||||
1. The encoder of the FarSeg model is ResNet50. |
||||
2. We use conv-bn-relu instead of conv-relu-bn. |
||||
|
||||
Args: |
||||
num_classes (int): The number of target classes. |
||||
mid_channels (int, optional): The number of channels required by the ChangeMixin module. Default: 256. |
||||
inner_channels (int, optional): The number of filters used in the convolutional layers in the ChangeMixin module. |
||||
Default: 16. |
||||
num_convs (int, optional): The number of convolutional layers used in the ChangeMixin module. Default: 4. |
||||
scale_factor (float, optional): The scaling factor of the output upsampling layer. Default: 4.0. |
||||
""" |
||||
|
||||
def __init__( |
||||
self, |
||||
num_classes, |
||||
mid_channels=256, |
||||
inner_channels=16, |
||||
num_convs=4, |
||||
scale_factor=4.0, |
||||
): |
||||
# TODO: Configurable FarSeg model |
||||
class _FarSegWrapper(nn.Layer): |
||||
def __init__(self, seg_model): |
||||
super().__init__() |
||||
self._seg_model = seg_model |
||||
self._seg_model.cls_pred_conv = Identity() |
||||
|
||||
def forward(self, x): |
||||
feat_list = self._seg_model.en(x) |
||||
fpn_feat_list = self._seg_model.fpn(feat_list) |
||||
if self._seg_model.scene_relation: |
||||
c5 = feat_list[-1] |
||||
c6 = self._seg_model.gap(c5) |
||||
refined_fpn_feat_list = self._seg_model.sr(c6, fpn_feat_list) |
||||
else: |
||||
refined_fpn_feat_list = fpn_feat_list |
||||
final_feat = self._seg_model.decoder(refined_fpn_feat_list) |
||||
return [final_feat] |
||||
|
||||
seg_model = FarSeg(out_ch=mid_channels) |
||||
|
||||
super().__init__( |
||||
seg_model=_FarSegWrapper(seg_model), |
||||
num_classes=num_classes, |
||||
mid_channels=mid_channels, |
||||
inner_channels=inner_channels, |
||||
num_convs=num_convs, |
||||
scale_factor=scale_factor |
||||
) |
||||
|
||||
# NOTE: Currently, ChangeStar = FarSeg + ChangeMixin + SegHead |
||||
ChangeStar = ChangeStar_FarSeg |
Loading…
Reference in new issue