diff --git a/paddlers/rs_models/cd/changestar.py b/paddlers/rs_models/cd/changestar.py
index f2e6cb6..20c7cc5 100644
--- a/paddlers/rs_models/cd/changestar.py
+++ b/paddlers/rs_models/cd/changestar.py
@@ -14,7 +14,6 @@
 
 import paddle
 import paddle.nn as nn
-import paddle.nn.functional as F
 
 from paddlers.datasets.cd_dataset import MaskType
 from paddlers.rs_models.seg import FarSeg
@@ -22,7 +21,6 @@ from .layers import Conv3x3, Identity
 
 
 class _ChangeStarBase(nn.Layer):
-    USE_MULTITASK_DECODER = True
     OUT_TYPES = (MaskType.CD, MaskType.CD, MaskType.SEG_T1, MaskType.SEG_T2)
 
@@ -118,22 +116,12 @@
             def __init__(self, seg_model):
                 super(_FarSegWrapper, self).__init__()
                 self._seg_model = seg_model
-                self._seg_model.cls_pred_conv = Identity()
+                self._seg_model.cls_head = Identity()
 
             def forward(self, x):
-                feat_list = self._seg_model.en(x)
-                fpn_feat_list = self._seg_model.fpn(feat_list)
-                if self._seg_model.scene_relation:
-                    c5 = feat_list[-1]
-                    c6 = self._seg_model.gap(c5)
-                    refined_fpn_feat_list = self._seg_model.sr(c6,
-                                                               fpn_feat_list)
-                else:
-                    refined_fpn_feat_list = fpn_feat_list
-                final_feat = self._seg_model.decoder(refined_fpn_feat_list)
-                return [final_feat]
-
-        seg_model = FarSeg(out_ch=mid_channels)
+                return self._seg_model(x)
+
+        seg_model = FarSeg(decoder_out_channels=mid_channels)
 
         super(ChangeStar_FarSeg, self).__init__(
             seg_model=_FarSegWrapper(seg_model),
diff --git a/paddlers/rs_models/seg/farseg.py b/paddlers/rs_models/seg/farseg.py
index 7a5a62a..8c3e13f 100644
--- a/paddlers/rs_models/seg/farseg.py
+++ b/paddlers/rs_models/seg/farseg.py
@@ -20,25 +20,79 @@
 import math
 
 import paddle.nn as nn
 import paddle.nn.functional as F
-from paddle.vision.models import resnet50
-from paddle import nn
-import paddle.nn.functional as F
+from paddle.vision.models import resnet
 
-from .layers import (Identity, ConvReLU, kaiming_normal_init, constant_init)
+from paddlers.models.ppdet.modeling import initializer as init
 
 
-class FPN(nn.Layer):
-    """
-    Module that adds FPN on top of a list of feature maps.
-    The feature maps are currently supposed to be in increasing depth
-    order, and must be consecutive.
-    """
+class FPNConvBlock(nn.Conv2D):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 dilation=1):
+        super(FPNConvBlock, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=dilation * (kernel_size - 1) // 2,
+            dilation=dilation)
+        init.kaiming_uniform_(self.weight, a=1)
+        init.constant_(self.bias, value=0)
+
 
+class DefaultConvBlock(nn.Conv2D):
     def __init__(self,
-                 in_channels_list,
+                 in_channels,
                  out_channels,
-                 conv_block=ConvReLU,
-                 top_blocks=None):
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 bias_attr=None):
+        super(DefaultConvBlock, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            bias_attr=bias_attr)
+        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        if self.bias is not None:
+            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
+            bound = 1 / math.sqrt(fan_in)
+            init.uniform_(self.bias, -bound, bound)
+
+
+class ResNetEncoder(nn.Layer):
+    def __init__(self, backbone='resnet50', in_channels=3, pretrained=True):
+        super(ResNetEncoder, self).__init__()
+        self.resnet = getattr(resnet, backbone)(pretrained=pretrained)
+        if in_channels != 3:
+            self.resnet.conv1 = nn.Conv2D(
+                in_channels, 64, 7, stride=2, padding=3, bias_attr=False)
+
+        for layer in self.resnet.sublayers():
+            if isinstance(layer, (nn.BatchNorm2D, nn.SyncBatchNorm)):
+                layer._momentum = 0.1
+
+    def forward(self, x):
+        x = self.resnet.conv1(x)
+        x = self.resnet.bn1(x)
+        x = self.resnet.relu(x)
+        x = self.resnet.maxpool(x)
+
+        c2 = self.resnet.layer1(x)
+        c3 = self.resnet.layer2(c2)
+        c4 = self.resnet.layer3(c3)
+        c5 = self.resnet.layer4(c4)
+
+        return [c2, c3, c4, c5]
+
+
+class FPN(nn.Layer):
+    def __init__(self, in_channels_list, out_channels, conv_block=FPNConvBlock):
         super(FPN, self).__init__()
 
         inner_blocks = []
@@ -46,17 +100,10 @@ class FPN(nn.Layer):
         for idx, in_channels in enumerate(in_channels_list, 1):
             if in_channels == 0:
                 continue
-            inner_block_module = conv_block(in_channels, out_channels, 1)
-            layer_block_module = conv_block(out_channels, out_channels, 3, 1)
-            for module in [inner_block_module, layer_block_module]:
-                for m in module.sublayers():
-                    if isinstance(m, nn.Conv2D):
-                        kaiming_normal_init(m.weight)
-            inner_blocks.append(inner_block_module)
-            layer_blocks.append(layer_block_module)
+            inner_blocks.append(conv_block(in_channels, out_channels, 1))
+            layer_blocks.append(conv_block(out_channels, out_channels, 3, 1))
         self.inner_blocks = nn.LayerList(inner_blocks)
         self.layer_blocks = nn.LayerList(layer_blocks)
-        self.top_blocks = top_blocks
 
     def forward(self, x):
         last_inner = self.inner_blocks[-1](x[-1])
@@ -69,80 +116,55 @@
             inner_lateral = inner_block(feature)
             last_inner = inner_lateral + inner_top_down
             results.insert(0, layer_block(last_inner))
-        if isinstance(self.top_blocks, LastLevelP6P7):
-            last_results = self.top_blocks(x[-1], results[-1])
-            results.extend(last_results)
-        elif isinstance(self.top_blocks, LastLevelMaxPool):
-            last_results = self.top_blocks(results[-1])
-            results.extend(last_results)
         return tuple(results)
 
 
-class LastLevelMaxPool(nn.Layer):
-    def forward(self, x):
-        return [F.max_pool2d(x, 1, 2, 0)]
-
-
-class LastLevelP6P7(nn.Layer):
-    """
-    This module is used in RetinaNet to generate extra layers, P6 and P7.
-    """
-
-    def __init__(self, in_channels, out_channels):
-        super(LastLevelP6P7, self).__init__()
-        self.p6 = nn.Conv2D(in_channels, out_channels, 3, 2, 1)
-        self.p7 = nn.Conv2D(out_channels, out_channels, 3, 2, 1)
-        for module in [self.p6, self.p7]:
-            for m in module.sublayers():
-                kaiming_normal_init(m.weight)
-                constant_init(m.bias, value=0)
-        self.use_P5 = in_channels == out_channels
-
-    def forward(self, c5, p5):
-        x = p5 if self.use_P5 else c5
-        p6 = self.p6(x)
-        p7 = self.p7(F.relu(p6))
-        return [p6, p7]
-
-
-class SceneRelation(nn.Layer):
+class FSRelation(nn.Layer):
     def __init__(self,
                  in_channels,
-                 channel_list,
+                 channels_list,
                  out_channels,
-                 scale_aware_proj=True):
-        super(SceneRelation, self).__init__()
-        if scale_aware_proj:
+                 scale_aware_proj=True,
+                 conv_block=DefaultConvBlock):
+        super(FSRelation, self).__init__()
+
+        self.scale_aware_proj = scale_aware_proj
+
+        if self.scale_aware_proj:
             self.scene_encoder = nn.LayerList([
                 nn.Sequential(
-                    nn.Conv2D(in_channels, out_channels, 1),
-                    nn.ReLU(), nn.Conv2D(out_channels, out_channels, 1))
-                for _ in range(len(channel_list))
+                    conv_block(in_channels, out_channels, 1),
+                    nn.ReLU(), conv_block(out_channels, out_channels, 1))
+                for _ in range(len(channels_list))
             ])
         else:
-            # 2mlp
             self.scene_encoder = nn.Sequential(
-                nn.Conv2D(in_channels, out_channels, 1),
-                nn.ReLU(),
-                nn.Conv2D(out_channels, out_channels, 1), )
+                conv_block(in_channels, out_channels, 1),
+                nn.ReLU(), conv_block(out_channels, out_channels, 1))
+
         self.content_encoders = nn.LayerList()
         self.feature_reencoders = nn.LayerList()
-        for c in channel_list:
+        for channel in channels_list:
             self.content_encoders.append(
                 nn.Sequential(
-                    nn.Conv2D(c, out_channels, 1),
-                    nn.BatchNorm2D(out_channels), nn.ReLU()))
+                    conv_block(
+                        channel, out_channels, 1, bias_attr=True),
+                    nn.BatchNorm2D(
+                        out_channels, momentum=0.1),
+                    nn.ReLU()))
             self.feature_reencoders.append(
                 nn.Sequential(
-                    nn.Conv2D(c, out_channels, 1),
-                    nn.BatchNorm2D(out_channels), nn.ReLU()))
+                    conv_block(
+                        channel, out_channels, 1, bias_attr=True),
+                    nn.BatchNorm2D(
+                        out_channels, momentum=0.1),
+                    nn.ReLU()))
+
         self.normalizer = nn.Sigmoid()
 
-    def forward(self, scene_feature, features: list):
+    def forward(self, scene_feature, feature_list):
         content_feats = [
             c_en(p_feat)
-            for c_en, p_feat in zip(self.content_encoders, features)
+            for c_en, p_feat in zip(self.content_encoders, feature_list)
         ]
         if self.scale_aware_proj:
             scene_feats = [op(scene_feature) for op in self.scene_encoder]
@@ -157,7 +179,8 @@ class SceneRelation(nn.Layer):
             for cf in content_feats
         ]
         p_feats = [
-            op(p_feat) for op, p_feat in zip(self.feature_reencoders, features)
+            op(p_feat)
+            for op, p_feat in zip(self.feature_reencoders, feature_list)
         ]
         refined_feats = [r * p for r, p in zip(relations, p_feats)]
         return refined_feats
@@ -167,71 +190,40 @@ class AsymmetricDecoder(nn.Layer):
     def __init__(self,
                  in_channels,
                  out_channels,
-                 in_feat_output_strides=(4, 8, 16, 32),
-                 out_feat_output_stride=4,
-                 norm_fn=nn.BatchNorm2D,
-                 num_groups_gn=None):
+                 in_feature_output_strides=(4, 8, 16, 32),
+                 out_feature_output_stride=4,
+                 conv_block=DefaultConvBlock):
         super(AsymmetricDecoder, self).__init__()
-        if norm_fn == nn.BatchNorm2D:
-            norm_fn_args = dict(num_features=out_channels)
-        elif norm_fn == nn.GroupNorm:
-            if num_groups_gn is None:
-                raise ValueError(
-                    'When norm_fn is nn.GroupNorm, num_groups_gn is needed.')
-            norm_fn_args = dict(
-                num_groups=num_groups_gn, num_channels=out_channels)
-        else:
-            raise ValueError('Type of {} is not support.'.format(type(norm_fn)))
+
         self.blocks = nn.LayerList()
-        for in_feat_os in in_feat_output_strides:
-            num_upsample = int(math.log2(int(in_feat_os))) - int(
-                math.log2(int(out_feat_output_stride)))
+        for in_feature_output_stride in in_feature_output_strides:
+            num_upsample = int(math.log2(int(in_feature_output_stride))) - int(
+                math.log2(int(out_feature_output_stride)))
             num_layers = num_upsample if num_upsample != 0 else 1
             self.blocks.append(
                 nn.Sequential(*[
                     nn.Sequential(
-                        nn.Conv2D(
+                        conv_block(
                             in_channels if idx == 0 else out_channels,
                             out_channels,
                             3,
                             1,
                             1,
                             bias_attr=False),
-                        norm_fn(**norm_fn_args)
-                        if norm_fn is not None else Identity(),
+                        nn.BatchNorm2D(
+                            out_channels, momentum=0.1),
                         nn.ReLU(),
                         nn.UpsamplingBilinear2D(scale_factor=2) if num_upsample
-                        != 0 else Identity(), ) for idx in range(num_layers)
+                        != 0 else nn.Identity(), ) for idx in range(num_layers)
                 ]))
 
-    def forward(self, feat_list: list):
-        inner_feat_list = []
+    def forward(self, feature_list):
+        inner_feature_list = []
         for idx, block in enumerate(self.blocks):
-            decoder_feat = block(feat_list[idx])
-            inner_feat_list.append(decoder_feat)
-        out_feat = sum(inner_feat_list) / 4.
-        return out_feat
-
-
-class ResNet50Encoder(nn.Layer):
-    def __init__(self, in_ch=3, pretrained=True):
-        super(ResNet50Encoder, self).__init__()
-        self.resnet = resnet50(pretrained=pretrained)
-        if in_ch != 3:
-            self.resnet.conv1 = nn.Conv2D(
-                in_ch, 64, kernel_size=7, stride=2, padding=3, bias_attr=False)
-
-    def forward(self, inputs):
-        x = inputs
-        x = self.resnet.conv1(x)
-        x = self.resnet.bn1(x)
-        x = self.resnet.relu(x)
-        x = self.resnet.maxpool(x)
-        c2 = self.resnet.layer1(x)
-        c3 = self.resnet.layer2(c2)
-        c4 = self.resnet.layer3(c3)
-        c5 = self.resnet.layer4(c4)
-        return [c2, c3, c4, c5]
+            decoder_feature = block(feature_list[idx])
+            inner_feature_list.append(decoder_feature)
+        out_feature = sum(inner_feature_list) / len(inner_feature_list)
+        return out_feature
 
 
 class FarSeg(nn.Layer):
@@ -239,50 +231,66 @@
     The FarSeg implementation based on PaddlePaddle.
 
     The original article refers to
-    Zheng, Zhuo, et al. "Foreground-Aware Relation Network for Geospatial Object Segmentation in High Spatial Resolution
-    Remote Sensing Imagery"
-    (https://openaccess.thecvf.com/content_CVPR_2020/papers/Zheng_Foreground-Aware_Relation_Network_for_Geospatial_Object_Segmentation_in_High_Spatial_CVPR_2020_paper.pdf)
+    Zheng Z, Zhong Y, Wang J, et al. Foreground-aware relation network for geospatial object segmentation in
+    high spatial resolution remote sensing imagery[C]//Proceedings of the IEEE/CVF Conference on Computer Vision
+    and Pattern Recognition. 2020: 4096-4105.
 
     Args:
-        in_channels (int, optional): Number of bands of the input images. Default: 3.
-        num_classes (int, optional): Number of target classes. Default: 16.
-        fpn_ch_list (list[int]|tuple[int], optional): Channel list of the FPN. Default: (256, 512, 1024, 2048).
-        mid_ch (int, optional): Output channels of the FPN. Default: 256.
-        out_ch (int, optional): Output channels of the decoder. Default: 128.
-        sr_ch_list (list[int]|tuple[int], optional): Channel list of the foreground-scene relation module. Default: (256, 256, 256, 256).
-        pretrained_encoder (bool, optional): Whether to use a pretrained encoder. Default: True.
+        in_channels (int): Number of channels of the input images. Default: 3.
+        num_classes (int): Number of target classes. Default: 16.
+        backbone (str): Backbone network to use; any ResNet provided by `paddle.vision.models.resnet` is accepted. Default: 'resnet50'.
+        backbone_pretrained (bool): Whether to load ImageNet pretrained weights for the backbone network. Default: True.
+        fpn_out_channels (int): Number of output channels of the feature pyramid network (FPN). Default: 256.
+        fsr_out_channels (int): Number of output channels of the F-S relation module. Default: 256.
+        scale_aware_proj (bool): Whether to use scale-aware projection in the F-S relation module. Default: True.
+        decoder_out_channels (int): Number of output channels of the decoder. Default: 128.
     """
 
     def __init__(self,
                  in_channels=3,
                  num_classes=16,
-                 fpn_ch_list=(256, 512, 1024, 2048),
-                 mid_ch=256,
-                 out_ch=128,
-                 sr_ch_list=(256, 256, 256, 256),
-                 pretrained_encoder=True):
+                 backbone='resnet50',
+                 backbone_pretrained=True,
+                 fpn_out_channels=256,
+                 fsr_out_channels=256,
+                 scale_aware_proj=True,
+                 decoder_out_channels=128):
         super(FarSeg, self).__init__()
-        self.en = ResNet50Encoder(in_channels, pretrained_encoder)
-        self.fpn = FPN(in_channels_list=fpn_ch_list, out_channels=mid_ch)
+
+        backbone = backbone.lower()
+        self.encoder = ResNetEncoder(
+            backbone=backbone,
+            in_channels=in_channels,
+            pretrained=backbone_pretrained)
+
+        fpn_max_in_channels = 2048
+        if backbone in ['resnet18', 'resnet34']:
+            fpn_max_in_channels = 512
+        self.fpn = FPN(in_channels_list=[
+            fpn_max_in_channels // (2**(3 - i)) for i in range(4)
+        ],
+                       out_channels=fpn_out_channels)
+
+        self.gap = nn.AdaptiveAvgPool2D(1)
+        self.fsr = FSRelation(
+            in_channels=fpn_max_in_channels,
+            channels_list=[fpn_out_channels] * 4,
+            out_channels=fsr_out_channels,
+            scale_aware_proj=scale_aware_proj)
+
         self.decoder = AsymmetricDecoder(
-            in_channels=mid_ch, out_channels=out_ch)
-        self.cls_pred_conv = nn.Conv2D(out_ch, num_classes, 1)
-        self.upsample4x_op = nn.UpsamplingBilinear2D(scale_factor=4)
-        self.scene_relation = True if sr_ch_list is not None else False
-        if self.scene_relation:
-            self.gap = nn.AdaptiveAvgPool2D(1)
-            self.sr = SceneRelation(fpn_ch_list[-1], sr_ch_list, mid_ch)
+            in_channels=fsr_out_channels, out_channels=decoder_out_channels)
+
+        self.cls_head = nn.Sequential(
+            DefaultConvBlock(decoder_out_channels, num_classes, 1),
+            nn.UpsamplingBilinear2D(scale_factor=4))
 
     def forward(self, x):
-        feat_list = self.en(x)
-        fpn_feat_list = self.fpn(feat_list)
-        if self.scene_relation:
-            c5 = feat_list[-1]
-            c6 = self.gap(c5)
-            refined_fpn_feat_list = self.sr(c6, fpn_feat_list)
-        else:
-            refined_fpn_feat_list = fpn_feat_list
-        final_feat = self.decoder(refined_fpn_feat_list)
-        cls_pred = self.cls_pred_conv(final_feat)
-        cls_pred = self.upsample4x_op(cls_pred)
-        return [cls_pred]
+        feature_list = self.encoder(x)
+
+        fpn_feature_list = self.fpn(feature_list)
+        scene_feature = self.gap(feature_list[-1])
+        refined_feature_list = self.fsr(scene_feature, fpn_feature_list)
+
+        feature = self.decoder(refined_feature_list)
+        logit = self.cls_head(feature)
+        return [logit]
diff --git a/tests/rs_models/test_seg_models.py b/tests/rs_models/test_seg_models.py
index 156f311..813cfd9 100644
--- a/tests/rs_models/test_seg_models.py
+++ b/tests/rs_models/test_seg_models.py
@@ -53,10 +53,15 @@ class TestFarSegModel(TestSegModel):
 
     def set_specs(self):
         self.specs = [
-            dict(), dict(num_classes=20), dict(pretrained_encoder=False),
-            dict(in_channels=10)
+            dict(), dict(
+                in_channels=6, num_classes=10), dict(
+                    backbone='resnet18', backbone_pretrained=False), dict(
+                        fpn_out_channels=128,
+                        fsr_out_channels=64,
+                        decoder_out_channels=32),
+            dict(scale_aware_proj=False)
         ]
 
     def set_targets(self):
-        self.targets = [[self.get_zeros_array(16)], [self.get_zeros_array(20)],
+        self.targets = [[self.get_zeros_array(16)], [self.get_zeros_array(10)],
+                        [self.get_zeros_array(16)], [self.get_zeros_array(16)],
                         [self.get_zeros_array(16)]]
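
A minimal usage sketch of the refactored `FarSeg` interface, for reference only: it assumes the patched `paddlers` package is importable, mirrors the new constructor signature introduced above, and uses an arbitrary input shape.

```python
import paddle

from paddlers.rs_models.seg import FarSeg

# Build the model with the renamed constructor arguments introduced by this patch.
model = FarSeg(
    in_channels=3,
    num_classes=16,
    backbone='resnet18',        # any ResNet in paddle.vision.models.resnet
    backbone_pretrained=False,  # skip the weight download for a quick check
    fpn_out_channels=256,
    fsr_out_channels=256,
    scale_aware_proj=True,
    decoder_out_channels=128)
model.eval()

# forward() returns a single-element list; cls_head upsamples the stride-4
# decoder feature back to the input resolution.
x = paddle.randn([2, 3, 256, 256])
with paddle.no_grad():
    logit, = model(x)
print(logit.shape)  # [2, 16, 256, 256]
```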