diff --git a/paddlers/rs_models/clas/condensenet_v2.py b/paddlers/rs_models/clas/condensenetv2.py
similarity index 95%
rename from paddlers/rs_models/clas/condensenet_v2.py
rename to paddlers/rs_models/clas/condensenetv2.py
index 53bb5aa..ca2b222 100644
--- a/paddlers/rs_models/clas/condensenet_v2.py
+++ b/paddlers/rs_models/clas/condensenetv2.py
@@ -1,442 +1,442 @@
 # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #    http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
 This code is based on https://github.com/AgentMaker/Paddle-Image-Models
 The copyright of AgentMaker/Paddle-Image-Models is as follows:
 Apache License [see LICENSE for details]
 """
 
 import paddle
 import paddle.nn as nn
 
-__all__ = ["CondenseNetV2_a", "CondenseNetV2_b", "CondenseNetV2_c"]
+__all__ = ["CondenseNetV2_A", "CondenseNetV2_B", "CondenseNetV2_C"]
 
 
 class SELayer(nn.Layer):
     def __init__(self, inplanes, reduction=16):
         super(SELayer, self).__init__()
         self.avg_pool = nn.AdaptiveAvgPool2D(1)
         self.fc = nn.Sequential(
             nn.Linear(
                 inplanes, inplanes // reduction, bias_attr=False),
             nn.ReLU(),
             nn.Linear(
                 inplanes // reduction, inplanes, bias_attr=False),
             nn.Sigmoid(), )
 
     def forward(self, x):
         b, c, _, _ = x.shape
         y = self.avg_pool(x).reshape((b, c))
         y = self.fc(y).reshape((b, c, 1, 1))
         return x * paddle.expand(y, shape=x.shape)
 
 
 class HS(nn.Layer):
     def __init__(self):
         super(HS, self).__init__()
         self.relu6 = nn.ReLU6()
 
     def forward(self, inputs):
         return inputs * self.relu6(inputs + 3) / 6
 
 
 class Conv(nn.Sequential):
     def __init__(
             self,
             in_channels,
             out_channels,
             kernel_size,
             stride=1,
             padding=0,
             groups=1,
             activation="ReLU",
             bn_momentum=0.9, ):
         super(Conv, self).__init__()
         self.add_sublayer(
             "norm", nn.BatchNorm2D(
                 in_channels, momentum=bn_momentum))
         if activation == "ReLU":
             self.add_sublayer("activation", nn.ReLU())
         elif activation == "HS":
             self.add_sublayer("activation", HS())
         else:
             raise NotImplementedError
         self.add_sublayer(
             "conv",
             nn.Conv2D(
                 in_channels,
                 out_channels,
                 kernel_size=kernel_size,
                 stride=stride,
                 padding=padding,
                 bias_attr=False,
                 groups=groups, ), )
 
 
 def ShuffleLayer(x, groups):
     batchsize, num_channels, height, width = x.shape
     channels_per_group = num_channels // groups
     # Reshape
     x = x.reshape((batchsize, groups, channels_per_group, height, width))
     # Transpose
     x = x.transpose((0, 2, 1, 3, 4))
     # Reshape
     x = x.reshape((batchsize, groups * channels_per_group, height, width))
     return x
 
 
 def ShuffleLayerTrans(x, groups):
     batchsize, num_channels, height, width = x.shape
     channels_per_group = num_channels // groups
     # Reshape
     x = x.reshape((batchsize, channels_per_group, groups, height, width))
     # Transpose
     x = x.transpose((0, 2, 1, 3, 4))
     # Reshape
     x = x.reshape((batchsize, channels_per_group * groups, height, width))
     return x
 
 
 class CondenseLGC(nn.Layer):
     def __init__(
             self,
             in_channels,
             out_channels,
             kernel_size,
             stride=1,
             padding=0,
             groups=1,
             activation="ReLU", ):
         super(CondenseLGC, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.groups = groups
         self.norm = nn.BatchNorm2D(self.in_channels)
         if activation == "ReLU":
             self.activation = nn.ReLU()
         elif activation == "HS":
             self.activation = HS()
         else:
             raise NotImplementedError
         self.conv = nn.Conv2D(
             self.in_channels,
             self.out_channels,
             kernel_size=kernel_size,
             stride=stride,
             padding=padding,
             groups=self.groups,
             bias_attr=False, )
         self.register_buffer(
             "index", paddle.zeros(
                 (self.in_channels, ), dtype="int64"))
 
     def forward(self, x):
         x = paddle.index_select(x, self.index, axis=1)
         x = self.norm(x)
         x = self.activation(x)
         x = self.conv(x)
         x = ShuffleLayer(x, self.groups)
         return x
 
 
 class CondenseSFR(nn.Layer):
     def __init__(
             self,
             in_channels,
             out_channels,
             kernel_size,
             stride=1,
             padding=0,
             groups=1,
             activation="ReLU", ):
         super(CondenseSFR, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.groups = groups
         self.norm = nn.BatchNorm2D(self.in_channels)
         if activation == "ReLU":
             self.activation = nn.ReLU()
         elif activation == "HS":
             self.activation = HS()
         else:
             raise NotImplementedError
         self.conv = nn.Conv2D(
             self.in_channels,
             self.out_channels,
             kernel_size=kernel_size,
             padding=padding,
             groups=self.groups,
             bias_attr=False,
             stride=stride, )
         self.register_buffer("index",
                              paddle.zeros(
                                  (self.out_channels, self.out_channels)))
 
     def forward(self, x):
         x = self.norm(x)
         x = self.activation(x)
         x = ShuffleLayerTrans(x, self.groups)
         x = self.conv(x)  # SIZE: N, C, H, W
         N, C, H, W = x.shape
         x = x.reshape((N, C, H * W))
         x = x.transpose((0, 2, 1))  # SIZE: N, HW, C
         # x SIZE: N, HW, C; self.index SIZE: C, C; OUTPUT SIZE: N, HW, C
         x = paddle.matmul(x, self.index)
         x = x.transpose((0, 2, 1))  # SIZE: N, C, HW
         x = x.reshape((N, C, H, W))  # SIZE: N, C, H, W
         return x
 
 
 class _SFR_DenseLayer(nn.Layer):
     def __init__(
             self,
             in_channels,
             growth_rate,
             group_1x1,
             group_3x3,
             group_trans,
             bottleneck,
             activation,
             use_se=False, ):
         super(_SFR_DenseLayer, self).__init__()
         self.group_1x1 = group_1x1
         self.group_3x3 = group_3x3
         self.group_trans = group_trans
         self.use_se = use_se
         # 1x1 conv i --> b*k
         self.conv_1 = CondenseLGC(
             in_channels,
             bottleneck * growth_rate,
             kernel_size=1,
             groups=self.group_1x1,
             activation=activation, )
         # 3x3 conv b*k --> k
         self.conv_2 = Conv(
             bottleneck * growth_rate,
             growth_rate,
             kernel_size=3,
             padding=1,
             groups=self.group_3x3,
             activation=activation, )
         # 1x1 res conv k(8-16-32)--> i (k*l)
         self.sfr = CondenseSFR(
             growth_rate,
             in_channels,
             kernel_size=1,
             groups=self.group_trans,
             activation=activation, )
         if self.use_se:
             self.se = SELayer(inplanes=growth_rate, reduction=1)
 
     def forward(self, x):
         x_ = x
         x = self.conv_1(x)
         x = self.conv_2(x)
         if self.use_se:
             x = self.se(x)
         sfr_feature = self.sfr(x)
         y = x_ + sfr_feature
         return paddle.concat([y, x], 1)
 
 
 class _SFR_DenseBlock(nn.Sequential):
     def __init__(
             self,
             num_layers,
             in_channels,
             growth_rate,
             group_1x1,
             group_3x3,
             group_trans,
             bottleneck,
             activation,
             use_se, ):
         super(_SFR_DenseBlock, self).__init__()
         for i in range(num_layers):
             layer = _SFR_DenseLayer(
                 in_channels + i * growth_rate,
                 growth_rate,
                 group_1x1,
                 group_3x3,
                 group_trans,
                 bottleneck,
                 activation,
                 use_se, )
             self.add_sublayer("denselayer_%d" % (i + 1), layer)
 
 
 class _Transition(nn.Layer):
     def __init__(self):
         super(_Transition, self).__init__()
         self.pool = nn.AvgPool2D(kernel_size=2, stride=2)
 
     def forward(self, x):
         x = self.pool(x)
         return x
 
 
 class CondenseNetV2(nn.Layer):
     def __init__(
             self,
             stages,
             growth,
             HS_start_block,
             SE_start_block,
             fc_channel,
             group_1x1,
             group_3x3,
             group_trans,
             bottleneck,
             last_se_reduction,
             in_channels=3,
             class_num=1000, ):
         super(CondenseNetV2, self).__init__()
         self.stages = stages
         self.growth = growth
         self.in_channels = in_channels
         self.class_num = class_num
         self.last_se_reduction = last_se_reduction
         assert len(self.stages) == len(self.growth)
         self.progress = 0.0
 
         self.init_stride = 2
         self.pool_size = 7
 
         self.features = nn.Sequential()
         # Initial nChannels should be 3
         self.num_features = 2 * self.growth[0]
         # Dense-block 1 (224x224)
         self.features.add_sublayer(
             "init_conv",
             nn.Conv2D(
                 in_channels,
                 self.num_features,
                 kernel_size=3,
                 stride=self.init_stride,
                 padding=1,
                 bias_attr=False, ), )
         for i in range(len(self.stages)):
             activation = "HS" if i >= HS_start_block else "ReLU"
             use_se = True if i >= SE_start_block else False
             # Dense-block i
             self.add_block(i, group_1x1, group_3x3, group_trans, bottleneck,
                            activation, use_se)
 
         self.fc = nn.Linear(self.num_features, fc_channel)
         self.fc_act = HS()
 
         # Classifier layer
         if class_num > 0:
             self.classifier = nn.Linear(fc_channel, class_num)
         self._initialize()
 
     def add_block(self, i, group_1x1, group_3x3, group_trans, bottleneck,
                   activation, use_se):
         # Check if ith is the last one
         last = i == len(self.stages) - 1
         block = _SFR_DenseBlock(
             num_layers=self.stages[i],
             in_channels=self.num_features,
             growth_rate=self.growth[i],
             group_1x1=group_1x1,
             group_3x3=group_3x3,
             group_trans=group_trans,
             bottleneck=bottleneck,
             activation=activation,
             use_se=use_se, )
         self.features.add_sublayer("denseblock_%d" % (i + 1), block)
         self.num_features += self.stages[i] * self.growth[i]
         if not last:
             trans = _Transition()
             self.features.add_sublayer("transition_%d" % (i + 1), trans)
         else:
             self.features.add_sublayer("norm_last",
                                        nn.BatchNorm2D(self.num_features))
             self.features.add_sublayer("relu_last", nn.ReLU())
             self.features.add_sublayer("pool_last",
                                        nn.AvgPool2D(self.pool_size))
             # if useSE:
             self.features.add_sublayer(
                 "se_last",
                 SELayer(
                     self.num_features, reduction=self.last_se_reduction))
 
     def forward(self, x):
         features = self.features(x)
         out = features.reshape((features.shape[0], features.shape[1] *
                                 features.shape[2] * features.shape[3]))
         out = self.fc(out)
         out = self.fc_act(out)
 
         if self.class_num > 0:
             out = self.classifier(out)
 
         return out
 
     def _initialize(self):
         # Initialize
         for m in self.sublayers():
             if isinstance(m, nn.Conv2D):
                 nn.initializer.KaimingNormal()(m.weight)
             elif isinstance(m, nn.BatchNorm2D):
                 nn.initializer.Constant(value=1.0)(m.weight)
                 nn.initializer.Constant(value=0.0)(m.bias)
 
 
-def CondenseNetV2_a(**kwargs):
+def CondenseNetV2_A(**kwargs):
     model = CondenseNetV2(
         stages=[1, 1, 4, 6, 8],
         growth=[8, 8, 16, 32, 64],
         HS_start_block=2,
         SE_start_block=3,
         fc_channel=828,
         group_1x1=8,
         group_3x3=8,
         group_trans=8,
         bottleneck=4,
         last_se_reduction=16,
         **kwargs)
     return model
 
 
-def CondenseNetV2_b(**kwargs):
+def CondenseNetV2_B(**kwargs):
     model = CondenseNetV2(
         stages=[2, 4, 6, 8, 6],
         growth=[6, 12, 24, 48, 96],
         HS_start_block=2,
         SE_start_block=3,
         fc_channel=1024,
         group_1x1=6,
         group_3x3=6,
         group_trans=6,
         bottleneck=4,
         last_se_reduction=16,
         **kwargs)
     return model
 
 
-def CondenseNetV2_c(**kwargs):
+def CondenseNetV2_C(**kwargs):
     model = CondenseNetV2(
         stages=[4, 6, 8, 10, 8],
         growth=[8, 16, 32, 64, 128],
         HS_start_block=2,
         SE_start_block=3,
         fc_channel=1024,
         group_1x1=8,
         group_3x3=8,
         group_trans=8,
         bottleneck=4,
         last_se_reduction=16,
         **kwargs)
     return model
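
The rename above only changes file and symbol casing (condensenet_v2 -> condensenetv2, CondenseNetV2_a/b/c -> CondenseNetV2_A/B/C), so each call site needs a one-character update. As a quick smoke test of the renamed entry points — a minimal sketch, assuming the paddlers.rs_models.clas package re-exports the constructors (the updated tests further down rely on the same path):

import paddle
from paddlers.rs_models.clas import CondenseNetV2_A

# Build the smallest variant with its defaults (in_channels=3, class_num=1000).
model = CondenseNetV2_A()
# The network is designed around 224x224 inputs: the initial stride-2 conv and
# four 2x2 average-pooling transitions reduce 224 down to the final 7x7 pool.
x = paddle.randn([1, 3, 224, 224])
y = model(x)
print(y.shape)  # expected: [1, 1000]
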
diff --git a/paddlers/tasks/classifier.py b/paddlers/tasks/classifier.py
index 83c20fb..7af2c02 100644
--- a/paddlers/tasks/classifier.py
+++ b/paddlers/tasks/classifier.py
@@ -34,9 +34,7 @@ from paddlers.utils.checkpoint import cls_pretrain_weights_dict
 from paddlers.transforms import Resize, decode_image
 from .base import BaseModel
 
-__all__ = [
-    "ResNet50_vd", "MobileNetV3_small_x1_0", "HRNet_W18_C", "CondenseNetV2_b"
-]
+__all__ = ["ResNet50_vd", "MobileNetV3", "HRNet", "CondenseNetV2"]
 
 
 class BaseClassifier(BaseModel):
@@ -600,13 +598,13 @@ class ResNet50_vd(BaseClassifier):
             **params)
 
 
-class MobileNetV3_small_x1_0(BaseClassifier):
+class MobileNetV3(BaseClassifier):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
                  losses=None,
                  **params):
-        super(MobileNetV3_small_x1_0, self).__init__(
+        super(MobileNetV3, self).__init__(
             model_name='MobileNetV3_small_x1_0',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
@@ -614,13 +612,13 @@ class MobileNetV3_small_x1_0(BaseClassifier):
             **params)
 
 
-class HRNet_W18_C(BaseClassifier):
+class HRNet(BaseClassifier):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
                  losses=None,
                  **params):
-        super(HRNet_W18_C, self).__init__(
+        super(HRNet, self).__init__(
             model_name='HRNet_W18_C',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
@@ -628,15 +626,21 @@ class HRNet_W18_C(BaseClassifier):
             **params)
 
 
-class CondenseNetV2_b(BaseClassifier):
+class CondenseNetV2(BaseClassifier):
     def __init__(self,
                  num_classes=2,
                  use_mixed_loss=False,
                  losses=None,
+                 in_channels=3,
+                 arch='A',
                  **params):
-        super(CondenseNetV2_b, self).__init__(
-            model_name='CondenseNetV2_b',
+        if arch not in ('A', 'B', 'C'):
+            raise ValueError("{} is not a supported architecture.".format(arch))
+        model_name = 'CondenseNetV2_' + arch
+        super(CondenseNetV2, self).__init__(
+            model_name=model_name,
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
             losses=losses,
+            in_channels=in_channels,
             **params)
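
The per-variant task classes are replaced by a single CondenseNetV2 wrapper whose new arch argument selects the backbone; anything other than 'A', 'B', or 'C' raises the ValueError added above. A usage sketch (hypothetical 21-class setup, matching the UC Merced configs below):

import paddlers as pdrs

# arch='B' selects the CondenseNetV2_B backbone via model_name='CondenseNetV2_B'.
model = pdrs.tasks.clas.CondenseNetV2(num_classes=21, in_channels=3, arch='B')
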
diff --git a/test_tipc/configs/clas/condensenetv2/condensenetv2_ucmerced.yaml b/test_tipc/configs/clas/condensenetv2/condensenetv2_ucmerced.yaml
new file mode 100644
index 0000000..9808f00
--- /dev/null
+++ b/test_tipc/configs/clas/condensenetv2/condensenetv2_ucmerced.yaml
@@ -0,0 +1,10 @@
+# Configurations of CondenseNet V2 with UCMerced dataset
+
+_base_: ../_base_/ucmerced.yaml
+
+save_dir: ./test_tipc/output/clas/condensenetv2/
+
+model: !Node
+  type: CondenseNetV2
+  args:
+    num_classes: 21
\ No newline at end of file
diff --git a/test_tipc/configs/clas/condensenetv2/train_infer_python.txt b/test_tipc/configs/clas/condensenetv2/train_infer_python.txt
new file mode 100644
index 0000000..0e8832b
--- /dev/null
+++ b/test_tipc/configs/clas/condensenetv2/train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:clas:condensenetv2
+python:python
+gpu_list:0|0,1
+use_gpu:null|null
+--precision:null
+--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=10
+--save_dir:adaptive
+--train_batch_size:lite_train_lite_infer=16|lite_train_whole_infer=16|whole_train_whole_infer=16
+--model_path:null
+--config:lite_train_lite_infer=./test_tipc/configs/clas/condensenetv2/condensenetv2_ucmerced.yaml|lite_train_whole_infer=./test_tipc/configs/clas/condensenetv2/condensenetv2_ucmerced.yaml|whole_train_whole_infer=./test_tipc/configs/clas/condensenetv2/condensenetv2_ucmerced.yaml
+train_model_name:best_model
+null:null
+##
+trainer:norm
+norm_train:test_tipc/run_task.py train clas
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================export_params===========================
+--save_dir:adaptive
+--model_dir:adaptive
+--fixed_input_shape:[-1,3,256,256]
+norm_export:deploy/export/export_model.py
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+===========================infer_params===========================
+infer_model:null
+infer_export:null
+infer_quant:False
+inference:test_tipc/infer.py
+--device:cpu|gpu
+--enable_mkldnn:True
+--cpu_threads:6
+--batch_size:1
+--use_trt:False
+--precision:fp32
+--model_dir:null
+--config:null
+--save_log_path:null
+--benchmark:True
+--model_name:condensenetv2
+null:null
\ No newline at end of file
diff --git a/test_tipc/configs/clas/hrnet/hrnet.yaml b/test_tipc/configs/clas/hrnet/hrnet.yaml
deleted file mode 100644
index 4c9879f..0000000
--- a/test_tipc/configs/clas/hrnet/hrnet.yaml
+++ /dev/null
@@ -1,10 +0,0 @@
-# Basic configurations of HRNet
-
-_base_: ../_base_/ucmerced.yaml
-
-save_dir: ./test_tipc/output/clas/hrnet/
-
-model: !Node
-  type: HRNet_W18_C
-  args:
-    num_classes: 21
\ No newline at end of file
diff --git a/test_tipc/configs/clas/hrnet/hrnet_ucmerced.yaml b/test_tipc/configs/clas/hrnet/hrnet_ucmerced.yaml
index 3a09756..3a58807 100644
--- a/test_tipc/configs/clas/hrnet/hrnet_ucmerced.yaml
+++ b/test_tipc/configs/clas/hrnet/hrnet_ucmerced.yaml
@@ -5,6 +5,6 @@ _base_: ../_base_/ucmerced.yaml
 save_dir: ./test_tipc/output/clas/hrnet/
 
 model: !Node
-  type: HRNet_W18_C
+  type: HRNet
   args:
     num_classes: 21
\ No newline at end of file
diff --git a/tests/rs_models/test_clas_models.py b/tests/rs_models/test_clas_models.py
index ab184fa..3c89d78 100644
--- a/tests/rs_models/test_clas_models.py
+++ b/tests/rs_models/test_clas_models.py
@@ -18,7 +18,7 @@ from rs_models.test_model import TestModel
 __all__ = []
 
 
-class TestCDModel(TestModel):
+class TestClasModel(TestModel):
     DEFAULT_HW = (224, 224)
 
     def check_output(self, output, target):
@@ -36,3 +36,36 @@ class TestCDModel(TestModel):
     def set_targets(self):
         self.targets = [[self.DEFAULT_BATCH_SIZE, spec.get('num_classes', 2)]
                         for spec in self.specs]
+
+
+class TestCondenseNetV2AModel(TestClasModel):
+    MODEL_CLASS = paddlers.rs_models.clas.CondenseNetV2_A
+
+    def set_specs(self):
+        self.specs = [
+            dict(in_channels=3, num_classes=2),
+            dict(in_channels=10, num_classes=2),
+            dict(in_channels=3, num_classes=100)
+        ]  # yapf: disable
+
+
+class TestCondenseNetV2BModel(TestClasModel):
+    MODEL_CLASS = paddlers.rs_models.clas.CondenseNetV2_B
+
+    def set_specs(self):
+        self.specs = [
+            dict(in_channels=3, num_classes=2),
+            dict(in_channels=10, num_classes=2),
+            dict(in_channels=3, num_classes=100)
+        ]  # yapf: disable
+
+
+class TestCondenseNetV2CModel(TestClasModel):
+    MODEL_CLASS = paddlers.rs_models.clas.CondenseNetV2_C
+
+    def set_specs(self):
+        self.specs = [
+            dict(in_channels=3, num_classes=2),
+            dict(in_channels=10, num_classes=2),
+            dict(in_channels=3, num_classes=100)
+        ]  # yapf: disable
diff --git a/tutorials/train/classification/condensenetv2.py b/tutorials/train/classification/condensenetv2.py
new file mode 100644
index 0000000..62fd4f4
--- /dev/null
+++ b/tutorials/train/classification/condensenetv2.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python
+
+# Example script for training the CondenseNet V2 scene classification model
+# Before running this script, make sure the PaddleRS library is correctly installed
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# Dataset directory
+DATA_DIR = './data/ucmerced/'
+# Path of the training set `file_list` file
+TRAIN_FILE_LIST_PATH = './data/ucmerced/train.txt'
+# Path of the validation set `file_list` file
+EVAL_FILE_LIST_PATH = './data/ucmerced/val.txt'
+# Path of the dataset category information file
+LABEL_LIST_PATH = './data/ucmerced/labels.txt'
+# Experiment directory, where output model weights and results are saved
+EXP_DIR = './output/condensenetv2/'
+
+# Download and unzip the UC Merced dataset
+pdrs.utils.download_and_decompress(
+    'https://paddlers.bj.bcebos.com/datasets/ucmerced.zip', path='./data/')
+
+# Define the transforms (data augmentation, preprocessing, etc.) used for training and validation
+# Compose combines multiple transforms, which are executed sequentially in order
+# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
+train_transforms = T.Compose([
+    # Decode the image
+    T.DecodeImg(),
+    # Resize the image to 256x256
+    T.Resize(target_size=256),
+    # Random horizontal flip with a probability of 50%
+    T.RandomHorizontalFlip(prob=0.5),
+    # Random vertical flip with a probability of 50%
+    T.RandomVerticalFlip(prob=0.5),
+    # Normalize the data to [-1, 1]
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeClassifier('train')
+])
+
+eval_transforms = T.Compose([
+    T.DecodeImg(),
+    T.Resize(target_size=256),
+    # Validation must use the same normalization as training
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeClassifier('eval')
+])
+
+# Build the training and validation datasets
+train_dataset = pdrs.datasets.ClasDataset(
+    data_dir=DATA_DIR,
+    file_list=TRAIN_FILE_LIST_PATH,
+    label_list=LABEL_LIST_PATH,
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True)
+
+eval_dataset = pdrs.datasets.ClasDataset(
+    data_dir=DATA_DIR,
+    file_list=EVAL_FILE_LIST_PATH,
+    label_list=LABEL_LIST_PATH,
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False)
+
+# Build the CondenseNet V2 model
+# For currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
+# For model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/classifier.py
+model = pdrs.tasks.clas.CondenseNetV2(num_classes=len(train_dataset.labels))
+
+# Train the model
+model.train(
+    num_epochs=2,
+    train_dataset=train_dataset,
+    train_batch_size=16,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=1,
+    # Log once every this many iterations
+    log_interval_steps=50,
+    save_dir=EXP_DIR,
+    # Initial learning rate
+    learning_rate=0.01,
+    # Whether to use early stopping, terminating training when accuracy no longer improves
+    early_stop=False,
+    # Whether to enable VisualDL logging
+    use_vdl=True,
+    # Checkpoint from which to resume training
+    resume_checkpoint=None)
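
A possible follow-up to the tutorial: reload the best checkpoint written under EXP_DIR and evaluate it — a sketch, assuming the generic paddlers.tasks.load_model helper and the classifier task's evaluate method behave here as in other PaddleRS tasks:

import paddlers as pdrs

# Restore the weights saved at the best validation epoch.
model = pdrs.tasks.load_model('./output/condensenetv2/best_model')
# Compute validation metrics with the eval_dataset defined in the script above.
metrics = model.evaluate(eval_dataset)
print(metrics)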