[Feature] Update multispectral scene classification (#36)

own
Yizhou Chen 3 years ago committed by GitHub
parent 037d62f379
commit 39c82d943a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 11
      paddlers/custom_models/cd/__init__.py
  2. 0
      paddlers/custom_models/cd/backbones/__init__.py
  3. 0
      paddlers/custom_models/cd/backbones/resnet.py
  4. 0
      paddlers/custom_models/cd/bit.py
  5. 0
      paddlers/custom_models/cd/cdnet.py
  6. 0
      paddlers/custom_models/cd/changestar.py
  7. 0
      paddlers/custom_models/cd/dsamnet.py
  8. 0
      paddlers/custom_models/cd/dsifn.py
  9. 0
      paddlers/custom_models/cd/layers/__init__.py
  10. 0
      paddlers/custom_models/cd/layers/attention.py
  11. 0
      paddlers/custom_models/cd/layers/blocks.py
  12. 24
      paddlers/custom_models/cd/models/__init__.py
  13. 0
      paddlers/custom_models/cd/param_init.py
  14. 0
      paddlers/custom_models/cd/snunet.py
  15. 0
      paddlers/custom_models/cd/stanet.py
  16. 0
      paddlers/custom_models/cd/unet_ef.py
  17. 0
      paddlers/custom_models/cd/unet_siamconc.py
  18. 0
      paddlers/custom_models/cd/unet_siamdiff.py
  19. 2
      paddlers/custom_models/cls/__init__.py
  20. 441
      paddlers/custom_models/cls/condensenet_v2.py
  21. 2
      paddlers/custom_models/seg/__init__.py
  22. 86
      paddlers/custom_models/seg/farseg.py
  23. 1
      paddlers/custom_models/seg/layers/__init__.py
  24. 8
      paddlers/custom_models/seg/layers/layers_lib.py
  25. 0
      paddlers/custom_models/seg/layers/param_init.py
  26. 15
      paddlers/custom_models/seg/models/__init__.py
  27. 15
      paddlers/custom_models/seg/models/farseg/__init__.py
  28. 98
      paddlers/custom_models/seg/models/farseg/fpn.py
  29. 23
      paddlers/custom_models/seg/models/utils/torch_nn.py
  30. 4
      paddlers/datasets/clas_dataset.py
  31. 5
      paddlers/tasks/base.py
  32. 4
      paddlers/tasks/changedetector.py
  33. 32
      paddlers/tasks/classifier.py
  34. 6
      paddlers/tasks/segmenter.py
  35. 2
      paddlers/transforms/functions.py
  36. 49
      tutorials/train/classification/condensenetv2_b_rs_mul.py
  37. 16
      tutorials/train/classification/resnet50_vd_rs.py

@ -12,4 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import models
from .bit import BIT
from .cdnet import CDNet
from .dsifn import DSIFN
from .stanet import STANet
from .snunet import SNUNet
from .dsamnet import DSAMNet
from .changestar import ChangeStar
from .unet_ef import UNetEarlyFusion
from .unet_siamconc import UNetSiamConc
from .unet_siamdiff import UNetSiamDiff

@ -1,24 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .bit import BIT
from .cdnet import CDNet
from .dsifn import DSIFN
from .stanet import STANet
from .snunet import SNUNet
from .dsamnet import DSAMNet
from .changestar import ChangeStar
from .unet_ef import UNetEarlyFusion
from .unet_siamconc import UNetSiamConc
from .unet_siamdiff import UNetSiamDiff

@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .condensenet_v2 import CondenseNetV2_a, CondenseNetV2_b, CondenseNetV2_c

@ -0,0 +1,441 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/AgentMaker/Paddle-Image-Models
Ths copyright of AgentMaker/Paddle-Image-Models is as follows:
Apache License [see LICENSE for details]
"""
import paddle
import paddle.nn as nn
__all__ = ["CondenseNetV2_a", "CondenseNetV2_b", "CondenseNetV2_c"]
class SELayer(nn.Layer):
def __init__(self, inplanes, reduction=16):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.fc = nn.Sequential(
nn.Linear(
inplanes, inplanes // reduction, bias_attr=False),
nn.ReLU(),
nn.Linear(
inplanes // reduction, inplanes, bias_attr=False),
nn.Sigmoid(), )
def forward(self, x):
b, c, _, _ = x.shape
y = self.avg_pool(x).reshape((b, c))
y = self.fc(y).reshape((b, c, 1, 1))
return x * y.expand_as(x)
class HS(nn.Layer):
def __init__(self):
super(HS, self).__init__()
self.relu6 = nn.ReLU6()
def forward(self, inputs):
return inputs * self.relu6(inputs + 3) / 6
class Conv(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
groups=1,
activation="ReLU",
bn_momentum=0.9, ):
super(Conv, self).__init__()
self.add_sublayer(
"norm", nn.BatchNorm2D(
in_channels, momentum=bn_momentum))
if activation == "ReLU":
self.add_sublayer("activation", nn.ReLU())
elif activation == "HS":
self.add_sublayer("activation", HS())
else:
raise NotImplementedError
self.add_sublayer(
"conv",
nn.Conv2D(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias_attr=False,
groups=groups, ), )
def ShuffleLayer(x, groups):
batchsize, num_channels, height, width = x.shape
channels_per_group = num_channels // groups
# reshape
x = x.reshape((batchsize, groups, channels_per_group, height, width))
# transpose
x = x.transpose((0, 2, 1, 3, 4))
# reshape
x = x.reshape((batchsize, -1, height, width))
return x
def ShuffleLayerTrans(x, groups):
batchsize, num_channels, height, width = x.shape
channels_per_group = num_channels // groups
# reshape
x = x.reshape((batchsize, channels_per_group, groups, height, width))
# transpose
x = x.transpose((0, 2, 1, 3, 4))
# reshape
x = x.reshape((batchsize, -1, height, width))
return x
class CondenseLGC(nn.Layer):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
groups=1,
activation="ReLU", ):
super(CondenseLGC, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.groups = groups
self.norm = nn.BatchNorm2D(self.in_channels)
if activation == "ReLU":
self.activation = nn.ReLU()
elif activation == "HS":
self.activation = HS()
else:
raise NotImplementedError
self.conv = nn.Conv2D(
self.in_channels,
self.out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=self.groups,
bias_attr=False, )
self.register_buffer(
"index", paddle.zeros(
(self.in_channels, ), dtype="int64"))
def forward(self, x):
x = paddle.index_select(x, self.index, axis=1)
x = self.norm(x)
x = self.activation(x)
x = self.conv(x)
x = ShuffleLayer(x, self.groups)
return x
class CondenseSFR(nn.Layer):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
groups=1,
activation="ReLU", ):
super(CondenseSFR, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.groups = groups
self.norm = nn.BatchNorm2D(self.in_channels)
if activation == "ReLU":
self.activation = nn.ReLU()
elif activation == "HS":
self.activation = HS()
else:
raise NotImplementedError
self.conv = nn.Conv2D(
self.in_channels,
self.out_channels,
kernel_size=kernel_size,
padding=padding,
groups=self.groups,
bias_attr=False,
stride=stride, )
self.register_buffer("index",
paddle.zeros(
(self.out_channels, self.out_channels)))
def forward(self, x):
x = self.norm(x)
x = self.activation(x)
x = ShuffleLayerTrans(x, self.groups)
x = self.conv(x) # SIZE: N, C, H, W
N, C, H, W = x.shape
x = x.reshape((N, C, H * W))
x = x.transpose((0, 2, 1)) # SIZE: N, HW, C
# x SIZE: N, HW, C; self.index SIZE: C, C; OUTPUT SIZE: N, HW, C
x = paddle.matmul(x, self.index)
x = x.transpose((0, 2, 1)) # SIZE: N, C, HW
x = x.reshape((N, C, H, W)) # SIZE: N, C, HW
return x
class _SFR_DenseLayer(nn.Layer):
def __init__(
self,
in_channels,
growth_rate,
group_1x1,
group_3x3,
group_trans,
bottleneck,
activation,
use_se=False, ):
super(_SFR_DenseLayer, self).__init__()
self.group_1x1 = group_1x1
self.group_3x3 = group_3x3
self.group_trans = group_trans
self.use_se = use_se
# 1x1 conv i --> b*k
self.conv_1 = CondenseLGC(
in_channels,
bottleneck * growth_rate,
kernel_size=1,
groups=self.group_1x1,
activation=activation, )
# 3x3 conv b*k --> k
self.conv_2 = Conv(
bottleneck * growth_rate,
growth_rate,
kernel_size=3,
padding=1,
groups=self.group_3x3,
activation=activation, )
# 1x1 res conv k(8-16-32)--> i (k*l)
self.sfr = CondenseSFR(
growth_rate,
in_channels,
kernel_size=1,
groups=self.group_trans,
activation=activation, )
if self.use_se:
self.se = SELayer(inplanes=growth_rate, reduction=1)
def forward(self, x):
x_ = x
x = self.conv_1(x)
x = self.conv_2(x)
if self.use_se:
x = self.se(x)
sfr_feature = self.sfr(x)
y = x_ + sfr_feature
return paddle.concat([y, x], 1)
class _SFR_DenseBlock(nn.Sequential):
def __init__(
self,
num_layers,
in_channels,
growth_rate,
group_1x1,
group_3x3,
group_trans,
bottleneck,
activation,
use_se, ):
super(_SFR_DenseBlock, self).__init__()
for i in range(num_layers):
layer = _SFR_DenseLayer(
in_channels + i * growth_rate,
growth_rate,
group_1x1,
group_3x3,
group_trans,
bottleneck,
activation,
use_se, )
self.add_sublayer("denselayer_%d" % (i + 1), layer)
class _Transition(nn.Layer):
def __init__(self):
super(_Transition, self).__init__()
self.pool = nn.AvgPool2D(kernel_size=2, stride=2)
def forward(self, x):
x = self.pool(x)
return x
class CondenseNetV2(nn.Layer):
def __init__(
self,
stages,
growth,
HS_start_block,
SE_start_block,
fc_channel,
group_1x1,
group_3x3,
group_trans,
bottleneck,
last_se_reduction,
in_channels=3,
class_num=1000, ):
super(CondenseNetV2, self).__init__()
self.stages = stages
self.growth = growth
self.in_channels = in_channels
self.class_num = class_num
self.last_se_reduction = last_se_reduction
assert len(self.stages) == len(self.growth)
self.progress = 0.0
self.init_stride = 2
self.pool_size = 7
self.features = nn.Sequential()
# Initial nChannels should be 3
self.num_features = 2 * self.growth[0]
# Dense-block 1 (224x224)
self.features.add_sublayer(
"init_conv",
nn.Conv2D(
in_channels,
self.num_features,
kernel_size=3,
stride=self.init_stride,
padding=1,
bias_attr=False, ), )
for i in range(len(self.stages)):
activation = "HS" if i >= HS_start_block else "ReLU"
use_se = True if i >= SE_start_block else False
# Dense-block i
self.add_block(i, group_1x1, group_3x3, group_trans, bottleneck,
activation, use_se)
self.fc = nn.Linear(self.num_features, fc_channel)
self.fc_act = HS()
# Classifier layer
if class_num > 0:
self.classifier = nn.Linear(fc_channel, class_num)
self._initialize()
def add_block(self, i, group_1x1, group_3x3, group_trans, bottleneck,
activation, use_se):
# Check if ith is the last one
last = i == len(self.stages) - 1
block = _SFR_DenseBlock(
num_layers=self.stages[i],
in_channels=self.num_features,
growth_rate=self.growth[i],
group_1x1=group_1x1,
group_3x3=group_3x3,
group_trans=group_trans,
bottleneck=bottleneck,
activation=activation,
use_se=use_se, )
self.features.add_sublayer("denseblock_%d" % (i + 1), block)
self.num_features += self.stages[i] * self.growth[i]
if not last:
trans = _Transition()
self.features.add_sublayer("transition_%d" % (i + 1), trans)
else:
self.features.add_sublayer("norm_last",
nn.BatchNorm2D(self.num_features))
self.features.add_sublayer("relu_last", nn.ReLU())
self.features.add_sublayer("pool_last",
nn.AvgPool2D(self.pool_size))
# if useSE:
self.features.add_sublayer(
"se_last",
SELayer(
self.num_features, reduction=self.last_se_reduction))
def forward(self, x):
features = self.features(x)
out = features.reshape((features.shape[0], -1))
out = self.fc(out)
out = self.fc_act(out)
if self.class_num > 0:
out = self.classifier(out)
return out
def _initialize(self):
# initialize
for m in self.sublayers():
if isinstance(m, nn.Conv2D):
nn.initializer.KaimingNormal()(m.weight)
elif isinstance(m, nn.BatchNorm2D):
nn.initializer.Constant(value=1.0)(m.weight)
nn.initializer.Constant(value=0.0)(m.bias)
def CondenseNetV2_a(**kwargs):
model = CondenseNetV2(
stages=[1, 1, 4, 6, 8],
growth=[8, 8, 16, 32, 64],
HS_start_block=2,
SE_start_block=3,
fc_channel=828,
group_1x1=8,
group_3x3=8,
group_trans=8,
bottleneck=4,
last_se_reduction=16,
**kwargs)
return model
def CondenseNetV2_b(**kwargs):
model = CondenseNetV2(
stages=[2, 4, 6, 8, 6],
growth=[6, 12, 24, 48, 96],
HS_start_block=2,
SE_start_block=3,
fc_channel=1024,
group_1x1=6,
group_3x3=6,
group_trans=6,
bottleneck=4,
last_se_reduction=16,
**kwargs)
return model
def CondenseNetV2_c(**kwargs):
model = CondenseNetV2(
stages=[4, 6, 8, 10, 8],
growth=[8, 16, 32, 64, 128],
HS_start_block=2,
SE_start_block=3,
fc_channel=1024,
group_1x1=8,
group_3x3=8,
group_trans=8,
bottleneck=4,
last_se_reduction=16,
**kwargs)
return model

@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import models
from .farseg import FarSeg

@ -21,8 +21,90 @@ import math
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.vision.models import resnet50
from .fpn import FPN
from ..utils import Identity
from paddle import nn
import paddle.nn.functional as F
from .layers import (Identity, ConvReLU, kaiming_normal_init, constant_init)
class FPN(nn.Layer):
"""
Module that adds FPN on top of a list of feature maps.
The feature maps are currently supposed to be in increasing depth
order, and must be consecutive
"""
def __init__(self,
in_channels_list,
out_channels,
conv_block=ConvReLU,
top_blocks=None):
super(FPN, self).__init__()
self.inner_blocks = []
self.layer_blocks = []
for idx, in_channels in enumerate(in_channels_list, 1):
inner_block = "fpn_inner{}".format(idx)
layer_block = "fpn_layer{}".format(idx)
if in_channels == 0:
continue
inner_block_module = conv_block(in_channels, out_channels, 1)
layer_block_module = conv_block(out_channels, out_channels, 3, 1)
self.add_sublayer(inner_block, inner_block_module)
self.add_sublayer(layer_block, layer_block_module)
for module in [inner_block_module, layer_block_module]:
for m in module.sublayers():
if isinstance(m, nn.Conv2D):
kaiming_normal_init(m.weight)
self.inner_blocks.append(inner_block)
self.layer_blocks.append(layer_block)
self.top_blocks = top_blocks
def forward(self, x):
last_inner = getattr(self, self.inner_blocks[-1])(x[-1])
results = [getattr(self, self.layer_blocks[-1])(last_inner)]
for feature, inner_block, layer_block in zip(
x[:-1][::-1], self.inner_blocks[:-1][::-1],
self.layer_blocks[:-1][::-1]):
if not inner_block:
continue
inner_top_down = F.interpolate(
last_inner, scale_factor=2, mode="nearest")
inner_lateral = getattr(self, inner_block)(feature)
last_inner = inner_lateral + inner_top_down
results.insert(0, getattr(self, layer_block)(last_inner))
if isinstance(self.top_blocks, LastLevelP6P7):
last_results = self.top_blocks(x[-1], results[-1])
results.extend(last_results)
elif isinstance(self.top_blocks, LastLevelMaxPool):
last_results = self.top_blocks(results[-1])
results.extend(last_results)
return tuple(results)
class LastLevelMaxPool(nn.Layer):
def forward(self, x):
return [F.max_pool2d(x, 1, 2, 0)]
class LastLevelP6P7(nn.Layer):
"""
This module is used in RetinaNet to generate extra layers, P6 and P7.
"""
def __init__(self, in_channels, out_channels):
super(LastLevelP6P7, self).__init__()
self.p6 = nn.Conv2D(in_channels, out_channels, 3, 2, 1)
self.p7 = nn.Conv2D(out_channels, out_channels, 3, 2, 1)
for module in [self.p6, self.p7]:
for m in module.sublayers():
kaiming_normal_init(m.weight)
constant_init(m.bias, value=0)
self.use_P5 = in_channels == out_channels
def forward(self, c5, p5):
x = p5 if self.use_P5 else c5
p6 = self.p6(x)
p7 = self.p7(F.relu(p6))
return [p6, p7]
class SceneRelation(nn.Layer):

@ -12,6 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .torch_nn import *
from .param_init import *
from .layers_lib import *

@ -138,3 +138,11 @@ class Activation(nn.Layer):
return self.act_func(x)
else:
return x
class Identity(nn.Layer):
def __init__(self, *args, **kwargs):
super(Identity, self).__init__()
def forward(self, input):
return input

@ -1,15 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .farseg import FarSeg

@ -1,15 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .farseg import FarSeg

@ -1,98 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import nn
import paddle.nn.functional as F
from ..utils import (ConvReLU, kaiming_normal_init, constant_init)
class FPN(nn.Layer):
"""
Module that adds FPN on top of a list of feature maps.
The feature maps are currently supposed to be in increasing depth
order, and must be consecutive
"""
def __init__(self,
in_channels_list,
out_channels,
conv_block=ConvReLU,
top_blocks=None):
super(FPN, self).__init__()
self.inner_blocks = []
self.layer_blocks = []
for idx, in_channels in enumerate(in_channels_list, 1):
inner_block = "fpn_inner{}".format(idx)
layer_block = "fpn_layer{}".format(idx)
if in_channels == 0:
continue
inner_block_module = conv_block(in_channels, out_channels, 1)
layer_block_module = conv_block(out_channels, out_channels, 3, 1)
self.add_sublayer(inner_block, inner_block_module)
self.add_sublayer(layer_block, layer_block_module)
for module in [inner_block_module, layer_block_module]:
for m in module.sublayers():
if isinstance(m, nn.Conv2D):
kaiming_normal_init(m.weight)
self.inner_blocks.append(inner_block)
self.layer_blocks.append(layer_block)
self.top_blocks = top_blocks
def forward(self, x):
last_inner = getattr(self, self.inner_blocks[-1])(x[-1])
results = [getattr(self, self.layer_blocks[-1])(last_inner)]
for feature, inner_block, layer_block in zip(
x[:-1][::-1], self.inner_blocks[:-1][::-1],
self.layer_blocks[:-1][::-1]):
if not inner_block:
continue
inner_top_down = F.interpolate(
last_inner, scale_factor=2, mode="nearest")
inner_lateral = getattr(self, inner_block)(feature)
last_inner = inner_lateral + inner_top_down
results.insert(0, getattr(self, layer_block)(last_inner))
if isinstance(self.top_blocks, LastLevelP6P7):
last_results = self.top_blocks(x[-1], results[-1])
results.extend(last_results)
elif isinstance(self.top_blocks, LastLevelMaxPool):
last_results = self.top_blocks(results[-1])
results.extend(last_results)
return tuple(results)
class LastLevelMaxPool(nn.Layer):
def forward(self, x):
return [F.max_pool2d(x, 1, 2, 0)]
class LastLevelP6P7(nn.Layer):
"""
This module is used in RetinaNet to generate extra layers, P6 and P7.
"""
def __init__(self, in_channels, out_channels):
super(LastLevelP6P7, self).__init__()
self.p6 = nn.Conv2D(in_channels, out_channels, 3, 2, 1)
self.p7 = nn.Conv2D(out_channels, out_channels, 3, 2, 1)
for module in [self.p6, self.p7]:
for m in module.sublayers():
kaiming_normal_init(m.weight)
constant_init(m.bias, value=0)
self.use_P5 = in_channels == out_channels
def forward(self, c5, p5):
x = p5 if self.use_P5 else c5
p6 = self.p6(x)
p7 = self.p7(F.relu(p6))
return [p6, p7]

@ -1,23 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.nn as nn
class Identity(nn.Layer):
def __init__(self, *args, **kwargs):
super(Identity, self).__init__()
def forward(self, input):
return input

@ -64,10 +64,10 @@ class ClasDataset(Dataset):
"so the space cannot be in the image or label path, but the line[{}] of " \
" file_list[{}] has a space in the image or label path.".format(line, file_list))
items[0] = path_normalization(items[0])
if not is_pic(items[0]):
continue
full_path_im = osp.join(data_dir, items[0])
label = items[1]
if not is_pic(full_path_im):
continue
if not osp.exists(full_path_im):
raise IOError('Image file {} does not exist!'.format(
full_path_im))

@ -39,6 +39,7 @@ from .utils.infer_nets import InferNet
class BaseModel:
def __init__(self, model_type):
self.model_type = model_type
self.in_channels = None
self.num_classes = None
self.labels = None
self.version = paddlers.__version__
@ -130,8 +131,8 @@ class BaseModel:
info['version'] = paddlers.__version__
info['Model'] = self.__class__.__name__
info['_Attributes'] = dict(
[('model_type', self.model_type), ('num_classes', self.num_classes),
('labels', self.labels),
[('model_type', self.model_type), ('in_channels', self.in_channels),
('num_classes', self.num_classes), ('labels', self.labels),
('fixed_input_shape', self.fixed_input_shape),
('best_accuracy', self.best_accuracy),
('best_model_epoch', self.best_model_epoch)])

@ -24,7 +24,7 @@ import paddle.nn.functional as F
from paddle.static import InputSpec
import paddlers
import paddlers.custom_models.cd as cd
import paddlers.custom_models.cd as cmcd
import paddlers.utils.logging as logging
import paddlers.models.ppseg as paddleseg
from paddlers.transforms import arrange_transforms
@ -65,7 +65,7 @@ class BaseChangeDetector(BaseModel):
def build_net(self, **params):
# TODO: add other model
net = cd.models.__dict__[self.model_name](num_classes=self.num_classes,
net = cmcd.__dict__[self.model_name](num_classes=self.num_classes,
**params)
return net

@ -20,6 +20,7 @@ import paddle
import paddle.nn.functional as F
from paddle.static import InputSpec
import paddlers.models.ppcls as paddleclas
import paddlers.custom_models.cls as cmcls
import paddlers
from paddlers.transforms import arrange_transforms
from paddlers.utils import get_single_card_bs, DisablePrint
@ -31,12 +32,15 @@ from paddlers.models.ppcls.data.postprocess import build_postprocess
from paddlers.utils.checkpoint import cls_pretrain_weights_dict
from paddlers.transforms import ImgDecoder, Resize
__all__ = ["ResNet50_vd", "MobileNetV3_small_x1_0", "HRNet_W18_C"]
__all__ = [
"ResNet50_vd", "MobileNetV3_small_x1_0", "HRNet_W18_C", "CondenseNetV2_b"
]
class BaseClassifier(BaseModel):
def __init__(self,
model_name,
in_channels=3,
num_classes=2,
use_mixed_loss=False,
**params):
@ -44,10 +48,12 @@ class BaseClassifier(BaseModel):
if 'with_net' in self.init_params:
del self.init_params['with_net']
super(BaseClassifier, self).__init__('classifier')
if not hasattr(paddleclas.arch.backbone, model_name):
if not hasattr(paddleclas.arch.backbone, model_name) and \
not hasattr(cmcls, model_name):
raise Exception("ERROR: There's no model named {}.".format(
model_name))
self.model_name = model_name
self.in_channels = in_channels
self.num_classes = num_classes
self.use_mixed_loss = use_mixed_loss
self.metrics = None
@ -61,8 +67,17 @@ class BaseClassifier(BaseModel):
def build_net(self, **params):
with paddle.utils.unique_name.guard():
net = paddleclas.arch.backbone.__dict__[self.model_name](
class_num=self.num_classes, **params)
model = dict(paddleclas.arch.backbone.__dict__,
**cmcls.__dict__)[self.model_name]
# TODO: Determine whether there is in_channels
try:
net = model(
class_num=self.num_classes,
in_channels=self.in_channels,
**params)
except:
net = model(class_num=self.num_classes, **params)
self.in_channels = 3
return net
def _fix_transforms_shape(self, image_shape):
@ -518,3 +533,12 @@ class HRNet_W18_C(BaseClassifier):
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
**params)
class CondenseNetV2_b(BaseClassifier):
def __init__(self, num_classes=2, use_mixed_loss=False, **params):
super(CondenseNetV2_b, self).__init__(
model_name='CondenseNetV2_b',
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
**params)

@ -21,7 +21,7 @@ import paddle
import paddle.nn.functional as F
from paddle.static import InputSpec
import paddlers.models.ppseg as paddleseg
import paddlers.custom_models.seg as seg
import paddlers.custom_models.seg as cmseg
import paddlers
from paddlers.transforms import arrange_transforms
from paddlers.utils import get_single_card_bs, DisablePrint
@ -45,7 +45,7 @@ class BaseSegmenter(BaseModel):
del self.init_params['with_net']
super(BaseSegmenter, self).__init__('segmenter')
if not hasattr(paddleseg.models, model_name) and \
not hasattr(seg.models, model_name):
not hasattr(cmseg, model_name):
raise Exception("ERROR: There's no model named {}.".format(
model_name))
self.model_name = model_name
@ -62,7 +62,7 @@ class BaseSegmenter(BaseModel):
# TODO: when using paddle.utils.unique_name.guard,
# DeepLabv3p and HRNet will raise a error
net = dict(paddleseg.models.__dict__,
**seg.models.__dict__)[self.model_name](
**cmseg.__dict__)[self.model_name](
num_classes=self.num_classes, **params)
return net

@ -315,7 +315,7 @@ def select_bands(im, band_list=[1, 2, 3]):
raise ValueError("The element in band_list must > 1 and <= {}.".
format(str(total_band)))
result.append(im[:, :, band])
ima = np.stack(result, axis=0)
ima = np.stack(result, axis=-1)
return ima

@ -0,0 +1,49 @@
import paddlers as pdrs
from paddlers import transforms as T
# 定义训练和验证时的transforms
train_transforms = T.Compose([
T.BandSelecting([5, 10, 15, 20, 25]), # for tet
T.Resize(target_size=224),
T.RandomHorizontalFlip(),
T.Normalize(
mean=[0.5, 0.5, 0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5, 0.5, 0.5]),
])
eval_transforms = T.Compose([
T.BandSelecting([5, 10, 15, 20, 25]),
T.Resize(target_size=224),
T.Normalize(
mean=[0.5, 0.5, 0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5, 0.5, 0.5]),
])
# 定义训练和验证所用的数据集
train_dataset = pdrs.datasets.ClasDataset(
data_dir='tutorials/train/classification/DataSet',
file_list='tutorials/train/classification/DataSet/train_list.txt',
label_list='tutorials/train/classification/DataSet/label_list.txt',
transforms=train_transforms,
num_workers=0,
shuffle=True)
eval_dataset = pdrs.datasets.ClasDataset(
data_dir='tutorials/train/classification/DataSet',
file_list='tutorials/train/classification/DataSet/val_list.txt',
label_list='tutorials/train/classification/DataSet/label_list.txt',
transforms=eval_transforms,
num_workers=0,
shuffle=False)
# 初始化模型
num_classes = len(train_dataset.labels)
model = pdrs.tasks.CondenseNetV2_b(in_channels=5, num_classes=num_classes)
# 进行训练
model.train(
num_epochs=100,
pretrain_weights=None,
train_dataset=train_dataset,
train_batch_size=4,
eval_dataset=eval_dataset,
learning_rate=3e-4,
save_dir='output/condensenetv2_b')

@ -1,7 +1,3 @@
import sys
sys.path.append("E:/dataFiles/github/PaddleRS")
import paddlers as pdrs
from paddlers import transforms as T
@ -9,7 +5,6 @@ from paddlers import transforms as T
# https://aistudio.baidu.com/aistudio/datasetdetail/63189
# 定义训练和验证时的transforms
# API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/transforms/transforms.md
train_transforms = T.Compose([
T.Resize(target_size=512),
T.RandomHorizontalFlip(),
@ -24,9 +19,8 @@ eval_transforms = T.Compose([
])
# 定义训练和验证所用的数据集
# API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/datasets.md
train_dataset = pdrs.datasets.ClasDataset(
data_dir='E:/dataFiles/github/PaddleRS/tutorials/train/classification/DataSet',
data_dir='tutorials/train/classification/DataSet',
file_list='tutorials/train/classification/DataSet/train_list.txt',
label_list='tutorials/train/classification/DataSet/label_list.txt',
transforms=train_transforms,
@ -34,20 +28,18 @@ train_dataset = pdrs.datasets.ClasDataset(
shuffle=True)
eval_dataset = pdrs.datasets.ClasDataset(
data_dir='E:/dataFiles/github/PaddleRS/tutorials/train/classification/DataSet',
data_dir='tutorials/train/classification/DataSet',
file_list='tutorials/train/classification/DataSet/test_list.txt',
label_list='tutorials/train/classification/DataSet/label_list.txt',
transforms=eval_transforms,
num_workers=0,
shuffle=False)
# 初始化模型,并进行训练
# 可使用VisualDL查看训练指标,参考https://github.com/PaddlePaddle/paddlers/blob/develop/docs/visualdl.md
# 初始化模型
num_classes = len(train_dataset.labels)
model = pdrs.tasks.ResNet50_vd(num_classes=num_classes)
# API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/models/semantic_segmentation.md
# 各参数介绍与调整说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/parameters.md
# 进行训练
model.train(
num_epochs=10,
train_dataset=train_dataset,

Loading…
Cancel
Save