Use common layers

own
Bobholamovic 2 years ago
parent 314383abb7
commit a607c53151
  1. 6
      paddlers/rs_models/cd/bit.py
  2. 46
      paddlers/rs_models/cd/cdnet.py
  3. 4
      paddlers/rs_models/cd/dsamnet.py
  4. 18
      paddlers/rs_models/cd/dsifn.py
  5. 2
      paddlers/rs_models/cd/fc_ef.py
  6. 2
      paddlers/rs_models/cd/fc_siam_conc.py
  7. 2
      paddlers/rs_models/cd/fc_siam_diff.py
  8. 2
      paddlers/rs_models/cd/layers/__init__.py
  9. 6
      paddlers/rs_models/cd/snunet.py
  10. 8
      paddlers/rs_models/cd/stanet.py
  11. 15
      paddlers/rs_models/clas/condensenet_v2.py
  12. 13
      paddlers/rs_models/layers/__init__.py
  13. 23
      paddlers/rs_models/layers/blocks.py
  14. 38
      paddlers/rs_models/seg/farseg.py
  15. 79
      paddlers/rs_models/seg/layers/layers_lib.py

@ -18,7 +18,7 @@ import paddle.nn.functional as F
from paddle.nn.initializer import Normal
from .backbones import resnet
from .layers import Conv3x3, Conv1x1, get_norm_layer, Identity
from .layers import Conv3x3, Conv1x1, get_bn_layer, Identity
from .param_init import KaimingInitMixin
@ -369,12 +369,12 @@ class Backbone(nn.Layer, KaimingInitMixin):
self.resnet = resnet.resnet18(
pretrained=pretrained,
strides=strides,
norm_layer=get_norm_layer())
norm_layer=get_bn_layer())
elif arch == 'resnet34':
self.resnet = resnet.resnet34(
pretrained=pretrained,
strides=strides,
norm_layer=get_norm_layer())
norm_layer=get_bn_layer())
else:
raise ValueError

@ -15,10 +15,25 @@
import paddle
import paddle.nn as nn
from .layers import Conv7x7
class CDNet(nn.Layer):
def __init__(self, in_channels=6, num_classes=2):
"""
The CDNet implementation based on PaddlePaddle.
The original article refers to
Pablo F. Alcantarilla, et al., "Street-View Change Detection with Deconvolutional Networks"
(https://link.springer.com/article/10.1007/s10514-018-9734-5).
Args:
in_channels (int): The number of bands of the input images.
num_classes (int): The number of target classes.
"""
def __init__(self, in_channels, num_classes):
super(CDNet, self).__init__()
self.conv1 = Conv7x7(in_channels, 64, norm=True, act=True)
self.pool1 = nn.MaxPool2D(2, 2, return_mask=True)
self.conv2 = Conv7x7(64, 64, norm=True, act=True)
@ -28,6 +43,7 @@ class CDNet(nn.Layer):
self.conv4 = Conv7x7(64, 64, norm=True, act=True)
self.pool4 = nn.MaxPool2D(2, 2, return_mask=True)
self.conv5 = Conv7x7(64, 64, norm=True, act=True)
self.upool4 = nn.MaxUnPool2D(2, 2)
self.conv6 = Conv7x7(64, 64, norm=True, act=True)
self.upool3 = nn.MaxUnPool2D(2, 2)
@ -39,37 +55,15 @@ class CDNet(nn.Layer):
def forward(self, t1, t2):
x = paddle.concat([t1, t2], axis=1)
x, ind1 = self.pool1(self.conv1(x))
x, ind2 = self.pool2(self.conv2(x))
x, ind3 = self.pool3(self.conv3(x))
x, ind4 = self.pool4(self.conv4(x))
x = self.conv5(self.upool4(x, ind4))
x = self.conv6(self.upool3(x, ind3))
x = self.conv7(self.upool2(x, ind2))
x = self.conv8(self.upool1(x, ind1))
return [self.conv_out(x)]
class Conv7x7(nn.Layer):
def __init__(self, in_ch, out_ch, norm=False, act=False):
super(Conv7x7, self).__init__()
layers = [
nn.Pad2D(3), nn.Conv2D(
in_ch, out_ch, 7, bias_attr=(False if norm else None))
]
if norm:
layers.append(nn.BatchNorm2D(out_ch))
if act:
layers.append(nn.ReLU())
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
if __name__ == "__main__":
t1 = paddle.randn((1, 3, 512, 512), dtype="float32")
t2 = paddle.randn((1, 3, 512, 512), dtype="float32")
model = CDNet(6, 2)
pred = model(t1, t2)[0]
print(pred.shape)
return [self.conv_out(x)]

@ -16,7 +16,7 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .layers import make_norm, Conv3x3, CBAM
from .layers import make_bn, Conv3x3, CBAM
from .stanet import Backbone, Decoder
@ -93,7 +93,7 @@ class DSLayer(nn.Sequential):
super(DSLayer, self).__init__(
nn.Conv2DTranspose(
in_ch, itm_ch, kernel_size=3, padding=1, **convd_kwargs),
make_norm(itm_ch),
make_bn(itm_ch),
nn.ReLU(),
nn.Dropout2D(p=0.2),
nn.Conv2DTranspose(

@ -20,7 +20,7 @@ import paddle.nn as nn
import paddle.nn.functional as F
from paddle.vision.models import vgg16
from .layers import Conv1x1, make_norm, ChannelAttention, SpatialAttention
from .layers import Conv1x1, make_bn, ChannelAttention, SpatialAttention
class DSIFN(nn.Layer):
@ -52,19 +52,19 @@ class DSIFN(nn.Layer):
self.sa5 = SpatialAttention()
self.ca1 = ChannelAttention(in_ch=1024)
self.bn_ca1 = make_norm(1024)
self.bn_ca1 = make_bn(1024)
self.o1_conv1 = conv2d_bn(1024, 512, use_dropout)
self.o1_conv2 = conv2d_bn(512, 512, use_dropout)
self.bn_sa1 = make_norm(512)
self.bn_sa1 = make_bn(512)
self.o1_conv3 = Conv1x1(512, num_classes)
self.trans_conv1 = nn.Conv2DTranspose(512, 512, kernel_size=2, stride=2)
self.ca2 = ChannelAttention(in_ch=1536)
self.bn_ca2 = make_norm(1536)
self.bn_ca2 = make_bn(1536)
self.o2_conv1 = conv2d_bn(1536, 512, use_dropout)
self.o2_conv2 = conv2d_bn(512, 256, use_dropout)
self.o2_conv3 = conv2d_bn(256, 256, use_dropout)
self.bn_sa2 = make_norm(256)
self.bn_sa2 = make_bn(256)
self.o2_conv4 = Conv1x1(256, num_classes)
self.trans_conv2 = nn.Conv2DTranspose(256, 256, kernel_size=2, stride=2)
@ -72,7 +72,7 @@ class DSIFN(nn.Layer):
self.o3_conv1 = conv2d_bn(768, 256, use_dropout)
self.o3_conv2 = conv2d_bn(256, 128, use_dropout)
self.o3_conv3 = conv2d_bn(128, 128, use_dropout)
self.bn_sa3 = make_norm(128)
self.bn_sa3 = make_bn(128)
self.o3_conv4 = Conv1x1(128, num_classes)
self.trans_conv3 = nn.Conv2DTranspose(128, 128, kernel_size=2, stride=2)
@ -80,7 +80,7 @@ class DSIFN(nn.Layer):
self.o4_conv1 = conv2d_bn(384, 128, use_dropout)
self.o4_conv2 = conv2d_bn(128, 64, use_dropout)
self.o4_conv3 = conv2d_bn(64, 64, use_dropout)
self.bn_sa4 = make_norm(64)
self.bn_sa4 = make_bn(64)
self.o4_conv4 = Conv1x1(64, num_classes)
self.trans_conv4 = nn.Conv2DTranspose(64, 64, kernel_size=2, stride=2)
@ -88,7 +88,7 @@ class DSIFN(nn.Layer):
self.o5_conv1 = conv2d_bn(192, 64, use_dropout)
self.o5_conv2 = conv2d_bn(64, 32, use_dropout)
self.o5_conv3 = conv2d_bn(32, 16, use_dropout)
self.bn_sa5 = make_norm(16)
self.bn_sa5 = make_bn(16)
self.o5_conv4 = Conv1x1(16, num_classes)
self.init_weight()
@ -211,7 +211,7 @@ def conv2d_bn(in_ch, out_ch, with_dropout=True):
nn.Conv2D(
in_ch, out_ch, kernel_size=3, stride=1, padding=1),
nn.PReLU(),
make_norm(out_ch),
make_bn(out_ch),
]
if with_dropout:
lst.append(nn.Dropout(p=0.6))

@ -26,7 +26,7 @@ class FCEarlyFusion(nn.Layer):
The FC-EF implementation based on PaddlePaddle.
The original article refers to
Caye Daudt, R., et al. "Fully convolutional siamese networks for change detection"
Rodrigo Caye Daudt, et al. "Fully convolutional siamese networks for change detection"
(https://arxiv.org/abs/1810.08462).
Args:

@ -26,7 +26,7 @@ class FCSiamConc(nn.Layer):
The FC-Siam-conc implementation based on PaddlePaddle.
The original article refers to
Caye Daudt, R., et al. "Fully convolutional siamese networks for change detection"
Rodrigo Caye Daudt, et al. "Fully convolutional siamese networks for change detection"
(https://arxiv.org/abs/1810.08462).
Args:

@ -26,7 +26,7 @@ class FCSiamDiff(nn.Layer):
The FC-Siam-diff implementation based on PaddlePaddle.
The original article refers to
Caye Daudt, R., et al. "Fully convolutional siamese networks for change detection"
Rodrigo Caye Daudt, et al. "Fully convolutional siamese networks for change detection"
(https://arxiv.org/abs/1810.08462).
Args:

@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .blocks import *
from ...layers.blocks import *
from .attention import ChannelAttention, SpatialAttention, CBAM

@ -18,7 +18,7 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .layers import Conv1x1, MaxPool2x2, make_norm, ChannelAttention
from .layers import Conv1x1, MaxPool2x2, make_bn, ChannelAttention
from .param_init import KaimingInitMixin
@ -145,9 +145,9 @@ class ConvBlockNested(nn.Layer):
super(ConvBlockNested, self).__init__()
self.act = nn.ReLU()
self.conv1 = nn.Conv2D(in_ch, mid_ch, kernel_size=3, padding=1)
self.bn1 = make_norm(mid_ch)
self.bn1 = make_bn(mid_ch)
self.conv2 = nn.Conv2D(mid_ch, out_ch, kernel_size=3, padding=1)
self.bn2 = make_norm(out_ch)
self.bn2 = make_bn(out_ch)
def forward(self, x):
x = self.conv1(x)

@ -17,7 +17,7 @@ import paddle.nn as nn
import paddle.nn.functional as F
from .backbones import resnet
from .layers import Conv1x1, Conv3x3, get_norm_layer, Identity
from .layers import Conv1x1, Conv3x3, get_bn_layer, Identity
from .param_init import KaimingInitMixin
@ -100,17 +100,17 @@ class Backbone(nn.Layer, KaimingInitMixin):
self.resnet = resnet.resnet18(
pretrained=pretrained,
strides=strides,
norm_layer=get_norm_layer())
norm_layer=get_bn_layer())
elif arch == 'resnet34':
self.resnet = resnet.resnet34(
pretrained=pretrained,
strides=strides,
norm_layer=get_norm_layer())
norm_layer=get_bn_layer())
elif arch == 'resnet50':
self.resnet = resnet.resnet50(
pretrained=pretrained,
strides=strides,
norm_layer=get_norm_layer())
norm_layer=get_bn_layer())
else:
raise ValueError

@ -20,6 +20,8 @@ Apache License [see LICENSE for details]
import paddle
import paddle.nn as nn
from ...layers.blocks import make_bn
__all__ = ["CondenseNetV2_a", "CondenseNetV2_b", "CondenseNetV2_c"]
@ -63,9 +65,7 @@ class Conv(nn.Sequential):
activation="ReLU",
bn_momentum=0.9, ):
super(Conv, self).__init__()
self.add_sublayer(
"norm", nn.BatchNorm2D(
in_channels, momentum=bn_momentum))
self.add_sublayer("norm", make_bn(in_channels, momentum=bn_momentum))
if activation == "ReLU":
self.add_sublayer("activation", nn.ReLU())
elif activation == "HS":
@ -122,7 +122,7 @@ class CondenseLGC(nn.Layer):
self.in_channels = in_channels
self.out_channels = out_channels
self.groups = groups
self.norm = nn.BatchNorm2D(self.in_channels)
self.norm = make_bn(self.in_channels)
if activation == "ReLU":
self.activation = nn.ReLU()
elif activation == "HS":
@ -164,7 +164,7 @@ class CondenseSFR(nn.Layer):
self.in_channels = in_channels
self.out_channels = out_channels
self.groups = groups
self.norm = nn.BatchNorm2D(self.in_channels)
self.norm = make_bn(self.in_channels)
if activation == "ReLU":
self.activation = nn.ReLU()
elif activation == "HS":
@ -361,8 +361,7 @@ class CondenseNetV2(nn.Layer):
trans = _Transition()
self.features.add_sublayer("transition_%d" % (i + 1), trans)
else:
self.features.add_sublayer("norm_last",
nn.BatchNorm2D(self.num_features))
self.features.add_sublayer("norm_last", make_bn(self.num_features))
self.features.add_sublayer("relu_last", nn.ReLU())
self.features.add_sublayer("pool_last",
nn.AvgPool2D(self.pool_size))
@ -389,7 +388,7 @@ class CondenseNetV2(nn.Layer):
for m in self.sublayers():
if isinstance(m, nn.Conv2D):
nn.initializer.KaimingNormal()(m.weight)
elif isinstance(m, nn.BatchNorm2D):
elif isinstance(m, (nn.BatchNorm2D, nn.SyncBatchNorm)):
nn.initializer.Constant(value=1.0)(m.weight)
nn.initializer.Constant(value=0.0)(m.bias)

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -12,18 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn as nn
__all__ = [
'BasicConv', 'Conv1x1', 'Conv3x3', 'Conv7x7', 'MaxPool2x2', 'MaxUnPool2x2',
'ConvTransposed3x3', 'Identity', 'get_norm_layer', 'get_act_layer',
'make_norm', 'make_act'
'ConvTransposed3x3', 'Identity', 'get_bn_layer', 'get_act_layer', 'make_bn',
'make_act'
]
def get_norm_layer():
# TODO: select appropriate norm layer.
def get_bn_layer():
if paddle.get_device() == 'cpu' or os.environ.get('PADDLESEG_EXPORT_STAGE'):
return nn.BatchNorm2D
elif paddle.distributed.ParallelEnv().nranks == 1:
return nn.BatchNorm2D
else:
return nn.SyncBatchNorm
def get_act_layer():
@ -31,8 +38,8 @@ def get_act_layer():
return nn.ReLU
def make_norm(*args, **kwargs):
norm_layer = get_norm_layer()
def make_bn(*args, **kwargs):
norm_layer = get_bn_layer()
return norm_layer(*args, **kwargs)
@ -66,7 +73,7 @@ class BasicConv(nn.Layer):
**kwargs))
if norm:
if norm is True:
norm = make_norm(out_ch)
norm = make_bn(out_ch)
seq.append(norm)
if act:
if act is True:
@ -171,7 +178,7 @@ class ConvTransposed3x3(nn.Layer):
**kwargs))
if norm:
if norm is True:
norm = make_norm(out_ch)
norm = make_bn(out_ch)
seq.append(norm)
if act:
if act is True:

@ -25,7 +25,8 @@ from paddle.vision.models import resnet50
from paddle import nn
import paddle.nn.functional as F
from .layers import (Identity, ConvReLU, kaiming_normal_init, constant_init)
from .layers import (Identity, Conv3x3, Conv1x1, get_bn_layer,
kaiming_normal_init, constant_init)
class FPN(nn.Layer):
@ -35,11 +36,7 @@ class FPN(nn.Layer):
order, and must be consecutive
"""
def __init__(self,
in_channels_list,
out_channels,
conv_block=ConvReLU,
top_blocks=None):
def __init__(self, in_channels_list, out_channels, top_blocks=None):
super(FPN, self).__init__()
inner_blocks = []
@ -47,8 +44,10 @@ class FPN(nn.Layer):
for idx, in_channels in enumerate(in_channels_list, 1):
if in_channels == 0:
continue
inner_block_module = conv_block(in_channels, out_channels, 1)
layer_block_module = conv_block(out_channels, out_channels, 3, 1)
inner_block_module = Conv1x1(
in_channels, out_channels, norm=False, act=True)
layer_block_module = Conv3x3(
out_channels, out_channels, norm=False, act=True)
for module in [inner_block_module, layer_block_module]:
for m in module.sublayers():
if isinstance(m, nn.Conv2D):
@ -131,13 +130,11 @@ class SceneRelation(nn.Layer):
self.feature_reencoders = nn.LayerList()
for c in channel_list:
self.content_encoders.append(
nn.Sequential(
nn.Conv2D(c, out_channels, 1),
nn.BatchNorm2D(out_channels), nn.ReLU()))
Conv1x1(
c, out_channels, norm=True, act=True))
self.feature_reencoders.append(
nn.Sequential(
nn.Conv2D(c, out_channels, 1),
nn.BatchNorm2D(out_channels), nn.ReLU()))
Conv1x1(
c, out_channels, norm=True, act=True))
self.normalizer = nn.Sigmoid()
def forward(self, scene_feature, features: list):
@ -170,19 +167,22 @@ class AssymetricDecoder(nn.Layer):
out_channels,
in_feat_output_strides=(4, 8, 16, 32),
out_feat_output_stride=4,
norm_fn=nn.BatchNorm2D,
norm_fn='batch_norm',
num_groups_gn=None):
super(AssymetricDecoder, self).__init__()
if norm_fn == nn.BatchNorm2D:
if norm_fn == 'batch_norm':
norm_fn = get_bn_layer()
norm_fn_args = dict(num_features=out_channels)
elif norm_fn == nn.GroupNorm:
elif norm_fn == 'group_norm':
norm_fn = nn.GroupNorm
if num_groups_gn is None:
raise ValueError(
'When norm_fn is nn.GroupNorm, num_groups_gn is needed.')
'When norm_fn is group_norm, `num_groups_gn` is needed.')
norm_fn_args = dict(
num_groups=num_groups_gn, num_channels=out_channels)
else:
raise ValueError('Type of {} is not support.'.format(type(norm_fn)))
raise ValueError('{} is not a supported normalization type.'.format(
norm_fn))
self.blocks = nn.LayerList()
for in_feat_os in in_feat_output_strides:
num_upsample = int(math.log2(int(in_feat_os))) - int(

@ -16,72 +16,7 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class ConvBNReLU(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
padding='same',
**kwargs):
super().__init__()
self._conv = nn.Conv2D(
in_channels, out_channels, kernel_size, padding=padding, **kwargs)
if 'data_format' in kwargs:
data_format = kwargs['data_format']
else:
data_format = 'NCHW'
self._batch_norm = nn.BatchNorm2D(out_channels, data_format=data_format)
def forward(self, x):
x = self._conv(x)
x = self._batch_norm(x)
x = F.relu(x)
return x
class ConvBN(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
padding='same',
**kwargs):
super().__init__()
self._conv = nn.Conv2D(
in_channels, out_channels, kernel_size, padding=padding, **kwargs)
if 'data_format' in kwargs:
data_format = kwargs['data_format']
else:
data_format = 'NCHW'
self._batch_norm = nn.BatchNorm2D(out_channels, data_format=data_format)
def forward(self, x):
x = self._conv(x)
x = self._batch_norm(x)
return x
class ConvReLU(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
padding='same',
**kwargs):
super().__init__()
self._conv = nn.Conv2D(
in_channels, out_channels, kernel_size, padding=padding, **kwargs)
if 'data_format' in kwargs:
data_format = kwargs['data_format']
else:
data_format = 'NCHW'
self._relu = nn.ReLU()
def forward(self, x):
x = self._conv(x)
x = self._relu(x)
return x
from ...layers.blocks import *
class Add(nn.Layer):
@ -95,15 +30,19 @@ class Add(nn.Layer):
class Activation(nn.Layer):
"""
The wrapper of activations.
Args:
act (str, optional): The activation name in lowercase. It must be one of ['elu', 'gelu',
'hardshrink', 'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid',
'softmax', 'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax',
'hsigmoid']. Default: None, means identical transformation.
Returns:
A callable object of Activation.
Raises:
KeyError: When parameter `act` is not in the optional range.
Examples:
from paddleseg.models.common.activation import Activation
relu = Activation("relu")
@ -138,11 +77,3 @@ class Activation(nn.Layer):
return self.act_func(x)
else:
return x
class Identity(nn.Layer):
def __init__(self, *args, **kwargs):
super(Identity, self).__init__()
def forward(self, input):
return input

Loading…
Cancel
Save