parent
14863747a7
commit
db802a389b
13 changed files with 587 additions and 4 deletions
@ -0,0 +1,15 @@ |
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
from .farseg import FarSeg |
@ -0,0 +1,15 @@ |
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
from .farseg import FarSeg |
@ -0,0 +1,171 @@ |
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
import math |
||||||
|
import paddle.nn as nn |
||||||
|
import paddle.nn.functional as F |
||||||
|
from paddle.vision.models import resnet50 |
||||||
|
from .fpn import FPN |
||||||
|
from ppseg.rs_models.utils import Identity |
||||||
|
|
||||||
|
|
||||||
|
class SceneRelation(nn.Layer): |
||||||
|
def __init__(self, |
||||||
|
in_channels, |
||||||
|
channel_list, |
||||||
|
out_channels, |
||||||
|
scale_aware_proj=True): |
||||||
|
super(SceneRelation, self).__init__() |
||||||
|
self.scale_aware_proj = scale_aware_proj |
||||||
|
if scale_aware_proj: |
||||||
|
self.scene_encoder = nn.LayerList([nn.Sequential( |
||||||
|
nn.Conv2D(in_channels, out_channels, 1), |
||||||
|
nn.ReLU(), |
||||||
|
nn.Conv2D(out_channels, out_channels, 1)) for _ in range(len(channel_list)) |
||||||
|
]) |
||||||
|
else: |
||||||
|
# 2mlp |
||||||
|
self.scene_encoder = nn.Sequential( |
||||||
|
nn.Conv2D(in_channels, out_channels, 1), |
||||||
|
nn.ReLU(), |
||||||
|
nn.Conv2D(out_channels, out_channels, 1), |
||||||
|
) |
||||||
|
self.content_encoders = nn.LayerList() |
||||||
|
self.feature_reencoders = nn.LayerList() |
||||||
|
for c in channel_list: |
||||||
|
self.content_encoders.append(nn.Sequential( |
||||||
|
nn.Conv2D(c, out_channels, 1), |
||||||
|
nn.BatchNorm2D(out_channels), |
||||||
|
nn.ReLU() |
||||||
|
)) |
||||||
|
self.feature_reencoders.append(nn.Sequential( |
||||||
|
nn.Conv2D(c, out_channels, 1), |
||||||
|
nn.BatchNorm2D(out_channels), |
||||||
|
nn.ReLU() |
||||||
|
)) |
||||||
|
self.normalizer = nn.Sigmoid() |
||||||
|
|
||||||
|
def forward(self, scene_feature, features: list): |
||||||
|
content_feats = [c_en(p_feat) for c_en, p_feat in zip(self.content_encoders, features)] |
||||||
|
if self.scale_aware_proj: |
||||||
|
scene_feats = [op(scene_feature) for op in self.scene_encoder] |
||||||
|
relations = [self.normalizer((sf * cf).sum(axis=1, keepdim=True)) |
||||||
|
for sf, cf in zip(scene_feats, content_feats)] |
||||||
|
else: |
||||||
|
scene_feat = self.scene_encoder(scene_feature) |
||||||
|
relations = [self.normalizer((scene_feat * cf).sum(axis=1, keepdim=True)) |
||||||
|
for cf in content_feats] |
||||||
|
p_feats = [op(p_feat) for op, p_feat in zip(self.feature_reencoders, features)] |
||||||
|
refined_feats = [r * p for r, p in zip(relations, p_feats)] |
||||||
|
return refined_feats |
||||||
|
|
||||||
|
|
||||||
|
class AssymetricDecoder(nn.Layer): |
||||||
|
def __init__(self, |
||||||
|
in_channels, |
||||||
|
out_channels, |
||||||
|
in_feat_output_strides=(4, 8, 16, 32), |
||||||
|
out_feat_output_stride=4, |
||||||
|
norm_fn=nn.BatchNorm2D, |
||||||
|
num_groups_gn=None): |
||||||
|
super(AssymetricDecoder, self).__init__() |
||||||
|
if norm_fn == nn.BatchNorm2D: |
||||||
|
norm_fn_args = dict(num_features=out_channels) |
||||||
|
elif norm_fn == nn.GroupNorm: |
||||||
|
if num_groups_gn is None: |
||||||
|
raise ValueError('When norm_fn is nn.GroupNorm, num_groups_gn is needed.') |
||||||
|
norm_fn_args = dict(num_groups=num_groups_gn, num_channels=out_channels) |
||||||
|
else: |
||||||
|
raise ValueError('Type of {} is not support.'.format(type(norm_fn))) |
||||||
|
self.blocks = nn.LayerList() |
||||||
|
for in_feat_os in in_feat_output_strides: |
||||||
|
num_upsample = int(math.log2(int(in_feat_os))) - int(math.log2(int(out_feat_output_stride))) |
||||||
|
num_layers = num_upsample if num_upsample != 0 else 1 |
||||||
|
self.blocks.append(nn.Sequential(*[ |
||||||
|
nn.Sequential( |
||||||
|
nn.Conv2D(in_channels if idx == 0 else out_channels, out_channels, 3, 1, 1, bias_attr=False), |
||||||
|
norm_fn(**norm_fn_args) if norm_fn is not None else Identity(), |
||||||
|
nn.ReLU(), |
||||||
|
nn.UpsamplingBilinear2D(scale_factor=2) if num_upsample != 0 else Identity(), |
||||||
|
) for idx in range(num_layers) |
||||||
|
])) |
||||||
|
|
||||||
|
def forward(self, feat_list: list): |
||||||
|
inner_feat_list = [] |
||||||
|
for idx, block in enumerate(self.blocks): |
||||||
|
decoder_feat = block(feat_list[idx]) |
||||||
|
inner_feat_list.append(decoder_feat) |
||||||
|
out_feat = sum(inner_feat_list) / 4. |
||||||
|
return out_feat |
||||||
|
|
||||||
|
|
||||||
|
class ResNet50Encoder(nn.Layer): |
||||||
|
def __init__(self, pretrained=True): |
||||||
|
super(ResNet50Encoder, self).__init__() |
||||||
|
self.resnet = resnet50(pretrained=pretrained) |
||||||
|
|
||||||
|
def forward(self, inputs): |
||||||
|
x = inputs |
||||||
|
x = self.resnet.conv1(x) |
||||||
|
x = self.resnet.bn1(x) |
||||||
|
x = self.resnet.relu(x) |
||||||
|
x = self.resnet.maxpool(x) |
||||||
|
c2 = self.resnet.layer1(x) |
||||||
|
c3 = self.resnet.layer2(c2) |
||||||
|
c4 = self.resnet.layer3(c3) |
||||||
|
c5 = self.resnet.layer4(c4) |
||||||
|
return [c2, c3, c4, c5] |
||||||
|
|
||||||
|
|
||||||
|
class FarSeg(nn.Layer): |
||||||
|
''' |
||||||
|
The FarSeg implementation based on PaddlePaddle. |
||||||
|
|
||||||
|
The original article refers to |
||||||
|
Zheng, Zhuo, et al. "Foreground-Aware Relation Network for Geospatial Object Segmentation in High Spatial Resolution Remote Sensing Imagery" |
||||||
|
(https://openaccess.thecvf.com/content_CVPR_2020/papers/Zheng_Foreground-Aware_Relation_Network_for_Geospatial_Object_Segmentation_in_High_Spatial_CVPR_2020_paper.pdf) |
||||||
|
''' |
||||||
|
def __init__(self, |
||||||
|
num_classes=16, |
||||||
|
fpn_ch_list=(256, 512, 1024, 2048), |
||||||
|
mid_ch=256, |
||||||
|
out_ch=128, |
||||||
|
sr_ch_list=(256, 256, 256, 256), |
||||||
|
encoder_pretrained=True): |
||||||
|
super(FarSeg, self).__init__() |
||||||
|
self.en = ResNet50Encoder(encoder_pretrained) |
||||||
|
self.fpn = FPN(in_channels_list=fpn_ch_list, |
||||||
|
out_channels=mid_ch) |
||||||
|
self.decoder = AssymetricDecoder(in_channels=mid_ch, out_channels=out_ch) |
||||||
|
self.cls_pred_conv = nn.Conv2D(out_ch, num_classes, 1) |
||||||
|
self.upsample4x_op = nn.UpsamplingBilinear2D(scale_factor=4) |
||||||
|
self.scene_relation = True if sr_ch_list is not None else False |
||||||
|
if self.scene_relation: |
||||||
|
self.gap = nn.AdaptiveAvgPool2D(1) |
||||||
|
self.sr = SceneRelation(fpn_ch_list[-1], sr_ch_list, mid_ch) |
||||||
|
|
||||||
|
def forward(self, x): |
||||||
|
feat_list = self.en(x) |
||||||
|
fpn_feat_list = self.fpn(feat_list) |
||||||
|
if self.scene_relation: |
||||||
|
c5 = feat_list[-1] |
||||||
|
c6 = self.gap(c5) |
||||||
|
refined_fpn_feat_list = self.sr(c6, fpn_feat_list) |
||||||
|
else: |
||||||
|
refined_fpn_feat_list = fpn_feat_list |
||||||
|
final_feat = self.decoder(refined_fpn_feat_list) |
||||||
|
cls_pred = self.cls_pred_conv(final_feat) |
||||||
|
cls_pred = self.upsample4x_op(cls_pred) |
||||||
|
cls_pred = F.softmax(cls_pred, axis=1) |
||||||
|
return [cls_pred] |
@ -0,0 +1,97 @@ |
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
from paddle import nn |
||||||
|
import paddle.nn.functional as F |
||||||
|
from ppseg.rs_models.utils import ( |
||||||
|
ConvReLU, kaiming_normal_init, constant_init |
||||||
|
) |
||||||
|
|
||||||
|
|
||||||
|
class FPN(nn.Layer): |
||||||
|
""" |
||||||
|
Module that adds FPN on top of a list of feature maps. |
||||||
|
The feature maps are currently supposed to be in increasing depth |
||||||
|
order, and must be consecutive |
||||||
|
""" |
||||||
|
def __init__(self, |
||||||
|
in_channels_list, |
||||||
|
out_channels, |
||||||
|
conv_block=ConvReLU, |
||||||
|
top_blocks=None |
||||||
|
): |
||||||
|
super(FPN, self).__init__() |
||||||
|
self.inner_blocks = [] |
||||||
|
self.layer_blocks = [] |
||||||
|
for idx, in_channels in enumerate(in_channels_list, 1): |
||||||
|
inner_block = "fpn_inner{}".format(idx) |
||||||
|
layer_block = "fpn_layer{}".format(idx) |
||||||
|
if in_channels == 0: |
||||||
|
continue |
||||||
|
inner_block_module = conv_block(in_channels, out_channels, 1) |
||||||
|
layer_block_module = conv_block(out_channels, out_channels, 3, 1) |
||||||
|
self.add_sublayer(inner_block, inner_block_module) |
||||||
|
self.add_sublayer(layer_block, layer_block_module) |
||||||
|
for module in [inner_block_module, layer_block_module]: |
||||||
|
for m in module.sublayers(): |
||||||
|
if isinstance(m, nn.Conv2D): |
||||||
|
kaiming_normal_init(m.weight) |
||||||
|
self.inner_blocks.append(inner_block) |
||||||
|
self.layer_blocks.append(layer_block) |
||||||
|
self.top_blocks = top_blocks |
||||||
|
|
||||||
|
def forward(self, x): |
||||||
|
last_inner = getattr(self, self.inner_blocks[-1])(x[-1]) |
||||||
|
results = [getattr(self, self.layer_blocks[-1])(last_inner)] |
||||||
|
for feature, inner_block, layer_block in zip( |
||||||
|
x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]): |
||||||
|
if not inner_block: |
||||||
|
continue |
||||||
|
inner_top_down = F.interpolate(last_inner, scale_factor=2, mode="nearest") |
||||||
|
inner_lateral = getattr(self, inner_block)(feature) |
||||||
|
last_inner = inner_lateral + inner_top_down |
||||||
|
results.insert(0, getattr(self, layer_block)(last_inner)) |
||||||
|
if isinstance(self.top_blocks, LastLevelP6P7): |
||||||
|
last_results = self.top_blocks(x[-1], results[-1]) |
||||||
|
results.extend(last_results) |
||||||
|
elif isinstance(self.top_blocks, LastLevelMaxPool): |
||||||
|
last_results = self.top_blocks(results[-1]) |
||||||
|
results.extend(last_results) |
||||||
|
return tuple(results) |
||||||
|
|
||||||
|
|
||||||
|
class LastLevelMaxPool(nn.Layer): |
||||||
|
def forward(self, x): |
||||||
|
return [F.max_pool2d(x, 1, 2, 0)] |
||||||
|
|
||||||
|
|
||||||
|
class LastLevelP6P7(nn.Layer): |
||||||
|
""" |
||||||
|
This module is used in RetinaNet to generate extra layers, P6 and P7. |
||||||
|
""" |
||||||
|
def __init__(self, in_channels, out_channels): |
||||||
|
super(LastLevelP6P7, self).__init__() |
||||||
|
self.p6 = nn.Conv2D(in_channels, out_channels, 3, 2, 1) |
||||||
|
self.p7 = nn.Conv2D(out_channels, out_channels, 3, 2, 1) |
||||||
|
for module in [self.p6, self.p7]: |
||||||
|
for m in module.sublayers(): |
||||||
|
kaiming_normal_init(m.weight) |
||||||
|
constant_init(m.bias, value=0) |
||||||
|
self.use_P5 = in_channels == out_channels |
||||||
|
|
||||||
|
def forward(self, c5, p5): |
||||||
|
x = p5 if self.use_P5 else c5 |
||||||
|
p6 = self.p6(x) |
||||||
|
p7 = self.p7(F.relu(p6)) |
||||||
|
return [p6, p7] |
@ -0,0 +1,17 @@ |
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
from .torch_nn import * |
||||||
|
from .param_init import * |
||||||
|
from .layers_lib import * |
@ -0,0 +1,139 @@ |
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
import paddle |
||||||
|
import paddle.nn as nn |
||||||
|
import paddle.nn.functional as F |
||||||
|
|
||||||
|
|
||||||
|
class ConvBNReLU(nn.Layer): |
||||||
|
def __init__(self, |
||||||
|
in_channels, |
||||||
|
out_channels, |
||||||
|
kernel_size, |
||||||
|
padding='same', |
||||||
|
**kwargs): |
||||||
|
super().__init__() |
||||||
|
self._conv = nn.Conv2D( |
||||||
|
in_channels, out_channels, kernel_size, padding=padding, **kwargs) |
||||||
|
if 'data_format' in kwargs: |
||||||
|
data_format = kwargs['data_format'] |
||||||
|
else: |
||||||
|
data_format = 'NCHW' |
||||||
|
self._batch_norm = nn.BatchNorm2D(out_channels, data_format=data_format) |
||||||
|
|
||||||
|
def forward(self, x): |
||||||
|
x = self._conv(x) |
||||||
|
x = self._batch_norm(x) |
||||||
|
x = F.relu(x) |
||||||
|
return x |
||||||
|
|
||||||
|
|
||||||
|
class ConvBN(nn.Layer): |
||||||
|
def __init__(self, |
||||||
|
in_channels, |
||||||
|
out_channels, |
||||||
|
kernel_size, |
||||||
|
padding='same', |
||||||
|
**kwargs): |
||||||
|
super().__init__() |
||||||
|
self._conv = nn.Conv2D( |
||||||
|
in_channels, out_channels, kernel_size, padding=padding, **kwargs) |
||||||
|
if 'data_format' in kwargs: |
||||||
|
data_format = kwargs['data_format'] |
||||||
|
else: |
||||||
|
data_format = 'NCHW' |
||||||
|
self._batch_norm = nn.BatchNorm2D(out_channels, data_format=data_format) |
||||||
|
|
||||||
|
def forward(self, x): |
||||||
|
x = self._conv(x) |
||||||
|
x = self._batch_norm(x) |
||||||
|
return x |
||||||
|
|
||||||
|
|
||||||
|
class ConvReLU(nn.Layer): |
||||||
|
def __init__(self, |
||||||
|
in_channels, |
||||||
|
out_channels, |
||||||
|
kernel_size, |
||||||
|
padding='same', |
||||||
|
**kwargs): |
||||||
|
super().__init__() |
||||||
|
self._conv = nn.Conv2D( |
||||||
|
in_channels, out_channels, kernel_size, padding=padding, **kwargs) |
||||||
|
if 'data_format' in kwargs: |
||||||
|
data_format = kwargs['data_format'] |
||||||
|
else: |
||||||
|
data_format = 'NCHW' |
||||||
|
self._relu = nn.ReLU() |
||||||
|
|
||||||
|
def forward(self, x): |
||||||
|
x = self._conv(x) |
||||||
|
x = self._relu(x) |
||||||
|
return x |
||||||
|
|
||||||
|
|
||||||
|
class Add(nn.Layer): |
||||||
|
def __init__(self): |
||||||
|
super().__init__() |
||||||
|
|
||||||
|
def forward(self, x, y, name=None): |
||||||
|
return paddle.add(x, y, name) |
||||||
|
|
||||||
|
|
||||||
|
class Activation(nn.Layer): |
||||||
|
""" |
||||||
|
The wrapper of activations. |
||||||
|
Args: |
||||||
|
act (str, optional): The activation name in lowercase. It must be one of ['elu', 'gelu', |
||||||
|
'hardshrink', 'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid', |
||||||
|
'softmax', 'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax', |
||||||
|
'hsigmoid']. Default: None, means identical transformation. |
||||||
|
Returns: |
||||||
|
A callable object of Activation. |
||||||
|
Raises: |
||||||
|
KeyError: When parameter `act` is not in the optional range. |
||||||
|
Examples: |
||||||
|
from paddleseg.models.common.activation import Activation |
||||||
|
relu = Activation("relu") |
||||||
|
print(relu) |
||||||
|
# <class 'paddle.nn.layer.activation.ReLU'> |
||||||
|
sigmoid = Activation("sigmoid") |
||||||
|
print(sigmoid) |
||||||
|
# <class 'paddle.nn.layer.activation.Sigmoid'> |
||||||
|
not_exit_one = Activation("not_exit_one") |
||||||
|
# KeyError: "not_exit_one does not exist in the current dict_keys(['elu', 'gelu', 'hardshrink', |
||||||
|
# 'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid', 'softmax', |
||||||
|
# 'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax', 'hsigmoid'])" |
||||||
|
""" |
||||||
|
def __init__(self, act=None): |
||||||
|
super(Activation, self).__init__() |
||||||
|
self._act = act |
||||||
|
upper_act_names = nn.layer.activation.__dict__.keys() |
||||||
|
lower_act_names = [act.lower() for act in upper_act_names] |
||||||
|
act_dict = dict(zip(lower_act_names, upper_act_names)) |
||||||
|
if act is not None: |
||||||
|
if act in act_dict.keys(): |
||||||
|
act_name = act_dict[act] |
||||||
|
self.act_func = eval( |
||||||
|
"nn.layer.activation.{}()".format(act_name)) |
||||||
|
else: |
||||||
|
raise KeyError("{} does not exist in the current {}".format( |
||||||
|
act, act_dict.keys())) |
||||||
|
|
||||||
|
def forward(self, x): |
||||||
|
if self._act is not None: |
||||||
|
return self.act_func(x) |
||||||
|
else: |
||||||
|
return x |
@ -0,0 +1,30 @@ |
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
import paddle.nn as nn |
||||||
|
|
||||||
|
|
||||||
|
def constant_init(param, **kwargs): |
||||||
|
initializer = nn.initializer.Constant(**kwargs) |
||||||
|
initializer(param, param.block) |
||||||
|
|
||||||
|
|
||||||
|
def normal_init(param, **kwargs): |
||||||
|
initializer = nn.initializer.Normal(**kwargs) |
||||||
|
initializer(param, param.block) |
||||||
|
|
||||||
|
|
||||||
|
def kaiming_normal_init(param, **kwargs): |
||||||
|
initializer = nn.initializer.KaimingNormal(**kwargs) |
||||||
|
initializer(param, param.block) |
@ -0,0 +1,23 @@ |
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
import paddle.nn as nn |
||||||
|
|
||||||
|
|
||||||
|
class Identity(nn.Layer): |
||||||
|
def __init__(self, *args, **kwargs): |
||||||
|
super(Identity, self).__init__() |
||||||
|
|
||||||
|
def forward(self, input): |
||||||
|
return input |
@ -0,0 +1,59 @@ |
|||||||
|
import sys |
||||||
|
|
||||||
|
sys.path.append("E:/dataFiles/github/PaddleRS") |
||||||
|
|
||||||
|
import paddlers as pdrs |
||||||
|
from paddlers import transforms as T |
||||||
|
|
||||||
|
# 下载和解压视盘分割数据集 |
||||||
|
optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz' |
||||||
|
pdrs.utils.download_and_decompress(optic_dataset, path='./') |
||||||
|
|
||||||
|
# 定义训练和验证时的transforms |
||||||
|
# API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/transforms/transforms.md |
||||||
|
train_transforms = T.Compose([ |
||||||
|
T.Resize(target_size=512), |
||||||
|
T.RandomHorizontalFlip(), |
||||||
|
T.Normalize( |
||||||
|
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), |
||||||
|
]) |
||||||
|
|
||||||
|
eval_transforms = T.Compose([ |
||||||
|
T.Resize(target_size=512), |
||||||
|
T.Normalize( |
||||||
|
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), |
||||||
|
]) |
||||||
|
|
||||||
|
# 定义训练和验证所用的数据集 |
||||||
|
# API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/datasets.md |
||||||
|
train_dataset = pdrs.datasets.SegDataset( |
||||||
|
data_dir='optic_disc_seg', |
||||||
|
file_list='optic_disc_seg/train_list.txt', |
||||||
|
label_list='optic_disc_seg/labels.txt', |
||||||
|
transforms=train_transforms, |
||||||
|
num_workers=0, |
||||||
|
shuffle=True) |
||||||
|
|
||||||
|
eval_dataset = pdrs.datasets.SegDataset( |
||||||
|
data_dir='optic_disc_seg', |
||||||
|
file_list='optic_disc_seg/val_list.txt', |
||||||
|
label_list='optic_disc_seg/labels.txt', |
||||||
|
transforms=eval_transforms, |
||||||
|
num_workers=0, |
||||||
|
shuffle=False) |
||||||
|
|
||||||
|
# 初始化模型,并进行训练 |
||||||
|
# 可使用VisualDL查看训练指标,参考https://github.com/PaddlePaddle/paddlers/blob/develop/docs/visualdl.md |
||||||
|
num_classes = len(train_dataset.labels) |
||||||
|
model = pdrs.tasks.FarSeg(num_classes=num_classes) |
||||||
|
|
||||||
|
# API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/models/semantic_segmentation.md |
||||||
|
# 各参数介绍与调整说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/parameters.md |
||||||
|
model.train( |
||||||
|
num_epochs=10, |
||||||
|
train_dataset=train_dataset, |
||||||
|
train_batch_size=4, |
||||||
|
eval_dataset=eval_dataset, |
||||||
|
learning_rate=0.01, |
||||||
|
pretrain_weights=None, |
||||||
|
save_dir='output/farseg') |
Loading…
Reference in new issue