# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import paddle import paddle.nn as nn import paddle.nn.functional as F from paddlers.datasets.cd_dataset import MaskType from paddlers.custom_models.seg import FarSeg from .layers import Conv3x3, Identity class _ChangeStarBase(nn.Layer): USE_MULTITASK_DECODER = True OUT_TYPES = (MaskType.CD, MaskType.CD, MaskType.SEG_T1, MaskType.SEG_T2) def __init__(self, seg_model, num_classes, mid_channels, inner_channels, num_convs, scale_factor): super(_ChangeStarBase, self).__init__() self.extract = seg_model self.detect = ChangeMixin( in_ch=mid_channels * 2, out_ch=num_classes, mid_ch=inner_channels, num_convs=num_convs, scale_factor=scale_factor) self.segment = nn.Sequential( Conv3x3(mid_channels, 2), nn.UpsamplingBilinear2D(scale_factor=scale_factor)) self.init_weight() def forward(self, t1, t2): x1 = self.extract(t1)[0] x2 = self.extract(t2)[0] logit12, logit21 = self.detect(x1, x2) if not self.training: logit_list = [logit12] else: logit1 = self.segment(x1) logit2 = self.segment(x2) logit_list = [logit12, logit21, logit1, logit2] return logit_list def init_weight(self): pass class ChangeMixin(nn.Layer): def __init__(self, in_ch, out_ch, mid_ch, num_convs, scale_factor): super(ChangeMixin, self).__init__() convs = [Conv3x3(in_ch, mid_ch, norm=True, act=True)] convs += [ Conv3x3( mid_ch, mid_ch, norm=True, act=True) for _ in range(num_convs - 1) ] self.detect = nn.Sequential( *convs, Conv3x3(mid_ch, out_ch), nn.UpsamplingBilinear2D(scale_factor=scale_factor)) def forward(self, x1, x2): pred12 = self.detect(paddle.concat([x1, x2], axis=1)) pred21 = self.detect(paddle.concat([x2, x1], axis=1)) return pred12, pred21 class ChangeStar_FarSeg(_ChangeStarBase): """ The ChangeStar implementation with a FarSeg encoder based on PaddlePaddle. The original article refers to Z. Zheng, et al., "Change is Everywhere: Single-Temporal Supervised Object Change Detection in Remote Sensing Imagery" (https://arxiv.org/abs/2108.07002). Note that this implementation differs from the original code in two aspects: 1. The encoder of the FarSeg model is ResNet50. 2. We use conv-bn-relu instead of conv-relu-bn. Args: num_classes (int): The number of target classes. mid_channels (int, optional): The number of channels required by the ChangeMixin module. Default: 256. inner_channels (int, optional): The number of filters used in the convolutional layers in the ChangeMixin module. Default: 16. num_convs (int, optional): The number of convolutional layers used in the ChangeMixin module. Default: 4. scale_factor (float, optional): The scaling factor of the output upsampling layer. Default: 4.0. """ def __init__( self, num_classes, mid_channels=256, inner_channels=16, num_convs=4, scale_factor=4.0, ): # TODO: Configurable FarSeg model class _FarSegWrapper(nn.Layer): def __init__(self, seg_model): super(_FarSegWrapper, self).__init__() self._seg_model = seg_model self._seg_model.cls_pred_conv = Identity() def forward(self, x): feat_list = self._seg_model.en(x) fpn_feat_list = self._seg_model.fpn(feat_list) if self._seg_model.scene_relation: c5 = feat_list[-1] c6 = self._seg_model.gap(c5) refined_fpn_feat_list = self._seg_model.sr(c6, fpn_feat_list) else: refined_fpn_feat_list = fpn_feat_list final_feat = self._seg_model.decoder(refined_fpn_feat_list) return [final_feat] seg_model = FarSeg(out_ch=mid_channels) super(ChangeStar_FarSeg, self).__init__( seg_model=_FarSegWrapper(seg_model), num_classes=num_classes, mid_channels=mid_channels, inner_channels=inner_channels, num_convs=num_convs, scale_factor=scale_factor) # NOTE: Currently, ChangeStar = FarSeg + ChangeMixin + SegHead ChangeStar = ChangeStar_FarSeg