parent
14863747a7
commit
db802a389b
13 changed files with 587 additions and 4 deletions
@ -0,0 +1,15 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
from .farseg import FarSeg |
@ -0,0 +1,15 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
from .farseg import FarSeg |
@ -0,0 +1,171 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import math |
||||
import paddle.nn as nn |
||||
import paddle.nn.functional as F |
||||
from paddle.vision.models import resnet50 |
||||
from .fpn import FPN |
||||
from ppseg.rs_models.utils import Identity |
||||
|
||||
|
||||
class SceneRelation(nn.Layer): |
||||
def __init__(self, |
||||
in_channels, |
||||
channel_list, |
||||
out_channels, |
||||
scale_aware_proj=True): |
||||
super(SceneRelation, self).__init__() |
||||
self.scale_aware_proj = scale_aware_proj |
||||
if scale_aware_proj: |
||||
self.scene_encoder = nn.LayerList([nn.Sequential( |
||||
nn.Conv2D(in_channels, out_channels, 1), |
||||
nn.ReLU(), |
||||
nn.Conv2D(out_channels, out_channels, 1)) for _ in range(len(channel_list)) |
||||
]) |
||||
else: |
||||
# 2mlp |
||||
self.scene_encoder = nn.Sequential( |
||||
nn.Conv2D(in_channels, out_channels, 1), |
||||
nn.ReLU(), |
||||
nn.Conv2D(out_channels, out_channels, 1), |
||||
) |
||||
self.content_encoders = nn.LayerList() |
||||
self.feature_reencoders = nn.LayerList() |
||||
for c in channel_list: |
||||
self.content_encoders.append(nn.Sequential( |
||||
nn.Conv2D(c, out_channels, 1), |
||||
nn.BatchNorm2D(out_channels), |
||||
nn.ReLU() |
||||
)) |
||||
self.feature_reencoders.append(nn.Sequential( |
||||
nn.Conv2D(c, out_channels, 1), |
||||
nn.BatchNorm2D(out_channels), |
||||
nn.ReLU() |
||||
)) |
||||
self.normalizer = nn.Sigmoid() |
||||
|
||||
def forward(self, scene_feature, features: list): |
||||
content_feats = [c_en(p_feat) for c_en, p_feat in zip(self.content_encoders, features)] |
||||
if self.scale_aware_proj: |
||||
scene_feats = [op(scene_feature) for op in self.scene_encoder] |
||||
relations = [self.normalizer((sf * cf).sum(axis=1, keepdim=True)) |
||||
for sf, cf in zip(scene_feats, content_feats)] |
||||
else: |
||||
scene_feat = self.scene_encoder(scene_feature) |
||||
relations = [self.normalizer((scene_feat * cf).sum(axis=1, keepdim=True)) |
||||
for cf in content_feats] |
||||
p_feats = [op(p_feat) for op, p_feat in zip(self.feature_reencoders, features)] |
||||
refined_feats = [r * p for r, p in zip(relations, p_feats)] |
||||
return refined_feats |
||||
|
||||
|
||||
class AssymetricDecoder(nn.Layer): |
||||
def __init__(self, |
||||
in_channels, |
||||
out_channels, |
||||
in_feat_output_strides=(4, 8, 16, 32), |
||||
out_feat_output_stride=4, |
||||
norm_fn=nn.BatchNorm2D, |
||||
num_groups_gn=None): |
||||
super(AssymetricDecoder, self).__init__() |
||||
if norm_fn == nn.BatchNorm2D: |
||||
norm_fn_args = dict(num_features=out_channels) |
||||
elif norm_fn == nn.GroupNorm: |
||||
if num_groups_gn is None: |
||||
raise ValueError('When norm_fn is nn.GroupNorm, num_groups_gn is needed.') |
||||
norm_fn_args = dict(num_groups=num_groups_gn, num_channels=out_channels) |
||||
else: |
||||
raise ValueError('Type of {} is not support.'.format(type(norm_fn))) |
||||
self.blocks = nn.LayerList() |
||||
for in_feat_os in in_feat_output_strides: |
||||
num_upsample = int(math.log2(int(in_feat_os))) - int(math.log2(int(out_feat_output_stride))) |
||||
num_layers = num_upsample if num_upsample != 0 else 1 |
||||
self.blocks.append(nn.Sequential(*[ |
||||
nn.Sequential( |
||||
nn.Conv2D(in_channels if idx == 0 else out_channels, out_channels, 3, 1, 1, bias_attr=False), |
||||
norm_fn(**norm_fn_args) if norm_fn is not None else Identity(), |
||||
nn.ReLU(), |
||||
nn.UpsamplingBilinear2D(scale_factor=2) if num_upsample != 0 else Identity(), |
||||
) for idx in range(num_layers) |
||||
])) |
||||
|
||||
def forward(self, feat_list: list): |
||||
inner_feat_list = [] |
||||
for idx, block in enumerate(self.blocks): |
||||
decoder_feat = block(feat_list[idx]) |
||||
inner_feat_list.append(decoder_feat) |
||||
out_feat = sum(inner_feat_list) / 4. |
||||
return out_feat |
||||
|
||||
|
||||
class ResNet50Encoder(nn.Layer): |
||||
def __init__(self, pretrained=True): |
||||
super(ResNet50Encoder, self).__init__() |
||||
self.resnet = resnet50(pretrained=pretrained) |
||||
|
||||
def forward(self, inputs): |
||||
x = inputs |
||||
x = self.resnet.conv1(x) |
||||
x = self.resnet.bn1(x) |
||||
x = self.resnet.relu(x) |
||||
x = self.resnet.maxpool(x) |
||||
c2 = self.resnet.layer1(x) |
||||
c3 = self.resnet.layer2(c2) |
||||
c4 = self.resnet.layer3(c3) |
||||
c5 = self.resnet.layer4(c4) |
||||
return [c2, c3, c4, c5] |
||||
|
||||
|
||||
class FarSeg(nn.Layer): |
||||
''' |
||||
The FarSeg implementation based on PaddlePaddle. |
||||
|
||||
The original article refers to |
||||
Zheng, Zhuo, et al. "Foreground-Aware Relation Network for Geospatial Object Segmentation in High Spatial Resolution Remote Sensing Imagery" |
||||
(https://openaccess.thecvf.com/content_CVPR_2020/papers/Zheng_Foreground-Aware_Relation_Network_for_Geospatial_Object_Segmentation_in_High_Spatial_CVPR_2020_paper.pdf) |
||||
''' |
||||
def __init__(self, |
||||
num_classes=16, |
||||
fpn_ch_list=(256, 512, 1024, 2048), |
||||
mid_ch=256, |
||||
out_ch=128, |
||||
sr_ch_list=(256, 256, 256, 256), |
||||
encoder_pretrained=True): |
||||
super(FarSeg, self).__init__() |
||||
self.en = ResNet50Encoder(encoder_pretrained) |
||||
self.fpn = FPN(in_channels_list=fpn_ch_list, |
||||
out_channels=mid_ch) |
||||
self.decoder = AssymetricDecoder(in_channels=mid_ch, out_channels=out_ch) |
||||
self.cls_pred_conv = nn.Conv2D(out_ch, num_classes, 1) |
||||
self.upsample4x_op = nn.UpsamplingBilinear2D(scale_factor=4) |
||||
self.scene_relation = True if sr_ch_list is not None else False |
||||
if self.scene_relation: |
||||
self.gap = nn.AdaptiveAvgPool2D(1) |
||||
self.sr = SceneRelation(fpn_ch_list[-1], sr_ch_list, mid_ch) |
||||
|
||||
def forward(self, x): |
||||
feat_list = self.en(x) |
||||
fpn_feat_list = self.fpn(feat_list) |
||||
if self.scene_relation: |
||||
c5 = feat_list[-1] |
||||
c6 = self.gap(c5) |
||||
refined_fpn_feat_list = self.sr(c6, fpn_feat_list) |
||||
else: |
||||
refined_fpn_feat_list = fpn_feat_list |
||||
final_feat = self.decoder(refined_fpn_feat_list) |
||||
cls_pred = self.cls_pred_conv(final_feat) |
||||
cls_pred = self.upsample4x_op(cls_pred) |
||||
cls_pred = F.softmax(cls_pred, axis=1) |
||||
return [cls_pred] |
@ -0,0 +1,97 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
from paddle import nn |
||||
import paddle.nn.functional as F |
||||
from ppseg.rs_models.utils import ( |
||||
ConvReLU, kaiming_normal_init, constant_init |
||||
) |
||||
|
||||
|
||||
class FPN(nn.Layer): |
||||
""" |
||||
Module that adds FPN on top of a list of feature maps. |
||||
The feature maps are currently supposed to be in increasing depth |
||||
order, and must be consecutive |
||||
""" |
||||
def __init__(self, |
||||
in_channels_list, |
||||
out_channels, |
||||
conv_block=ConvReLU, |
||||
top_blocks=None |
||||
): |
||||
super(FPN, self).__init__() |
||||
self.inner_blocks = [] |
||||
self.layer_blocks = [] |
||||
for idx, in_channels in enumerate(in_channels_list, 1): |
||||
inner_block = "fpn_inner{}".format(idx) |
||||
layer_block = "fpn_layer{}".format(idx) |
||||
if in_channels == 0: |
||||
continue |
||||
inner_block_module = conv_block(in_channels, out_channels, 1) |
||||
layer_block_module = conv_block(out_channels, out_channels, 3, 1) |
||||
self.add_sublayer(inner_block, inner_block_module) |
||||
self.add_sublayer(layer_block, layer_block_module) |
||||
for module in [inner_block_module, layer_block_module]: |
||||
for m in module.sublayers(): |
||||
if isinstance(m, nn.Conv2D): |
||||
kaiming_normal_init(m.weight) |
||||
self.inner_blocks.append(inner_block) |
||||
self.layer_blocks.append(layer_block) |
||||
self.top_blocks = top_blocks |
||||
|
||||
def forward(self, x): |
||||
last_inner = getattr(self, self.inner_blocks[-1])(x[-1]) |
||||
results = [getattr(self, self.layer_blocks[-1])(last_inner)] |
||||
for feature, inner_block, layer_block in zip( |
||||
x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]): |
||||
if not inner_block: |
||||
continue |
||||
inner_top_down = F.interpolate(last_inner, scale_factor=2, mode="nearest") |
||||
inner_lateral = getattr(self, inner_block)(feature) |
||||
last_inner = inner_lateral + inner_top_down |
||||
results.insert(0, getattr(self, layer_block)(last_inner)) |
||||
if isinstance(self.top_blocks, LastLevelP6P7): |
||||
last_results = self.top_blocks(x[-1], results[-1]) |
||||
results.extend(last_results) |
||||
elif isinstance(self.top_blocks, LastLevelMaxPool): |
||||
last_results = self.top_blocks(results[-1]) |
||||
results.extend(last_results) |
||||
return tuple(results) |
||||
|
||||
|
||||
class LastLevelMaxPool(nn.Layer): |
||||
def forward(self, x): |
||||
return [F.max_pool2d(x, 1, 2, 0)] |
||||
|
||||
|
||||
class LastLevelP6P7(nn.Layer): |
||||
""" |
||||
This module is used in RetinaNet to generate extra layers, P6 and P7. |
||||
""" |
||||
def __init__(self, in_channels, out_channels): |
||||
super(LastLevelP6P7, self).__init__() |
||||
self.p6 = nn.Conv2D(in_channels, out_channels, 3, 2, 1) |
||||
self.p7 = nn.Conv2D(out_channels, out_channels, 3, 2, 1) |
||||
for module in [self.p6, self.p7]: |
||||
for m in module.sublayers(): |
||||
kaiming_normal_init(m.weight) |
||||
constant_init(m.bias, value=0) |
||||
self.use_P5 = in_channels == out_channels |
||||
|
||||
def forward(self, c5, p5): |
||||
x = p5 if self.use_P5 else c5 |
||||
p6 = self.p6(x) |
||||
p7 = self.p7(F.relu(p6)) |
||||
return [p6, p7] |
@ -0,0 +1,17 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
from .torch_nn import * |
||||
from .param_init import * |
||||
from .layers_lib import * |
@ -0,0 +1,139 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
import paddle.nn.functional as F |
||||
|
||||
|
||||
class ConvBNReLU(nn.Layer): |
||||
def __init__(self, |
||||
in_channels, |
||||
out_channels, |
||||
kernel_size, |
||||
padding='same', |
||||
**kwargs): |
||||
super().__init__() |
||||
self._conv = nn.Conv2D( |
||||
in_channels, out_channels, kernel_size, padding=padding, **kwargs) |
||||
if 'data_format' in kwargs: |
||||
data_format = kwargs['data_format'] |
||||
else: |
||||
data_format = 'NCHW' |
||||
self._batch_norm = nn.BatchNorm2D(out_channels, data_format=data_format) |
||||
|
||||
def forward(self, x): |
||||
x = self._conv(x) |
||||
x = self._batch_norm(x) |
||||
x = F.relu(x) |
||||
return x |
||||
|
||||
|
||||
class ConvBN(nn.Layer): |
||||
def __init__(self, |
||||
in_channels, |
||||
out_channels, |
||||
kernel_size, |
||||
padding='same', |
||||
**kwargs): |
||||
super().__init__() |
||||
self._conv = nn.Conv2D( |
||||
in_channels, out_channels, kernel_size, padding=padding, **kwargs) |
||||
if 'data_format' in kwargs: |
||||
data_format = kwargs['data_format'] |
||||
else: |
||||
data_format = 'NCHW' |
||||
self._batch_norm = nn.BatchNorm2D(out_channels, data_format=data_format) |
||||
|
||||
def forward(self, x): |
||||
x = self._conv(x) |
||||
x = self._batch_norm(x) |
||||
return x |
||||
|
||||
|
||||
class ConvReLU(nn.Layer): |
||||
def __init__(self, |
||||
in_channels, |
||||
out_channels, |
||||
kernel_size, |
||||
padding='same', |
||||
**kwargs): |
||||
super().__init__() |
||||
self._conv = nn.Conv2D( |
||||
in_channels, out_channels, kernel_size, padding=padding, **kwargs) |
||||
if 'data_format' in kwargs: |
||||
data_format = kwargs['data_format'] |
||||
else: |
||||
data_format = 'NCHW' |
||||
self._relu = nn.ReLU() |
||||
|
||||
def forward(self, x): |
||||
x = self._conv(x) |
||||
x = self._relu(x) |
||||
return x |
||||
|
||||
|
||||
class Add(nn.Layer): |
||||
def __init__(self): |
||||
super().__init__() |
||||
|
||||
def forward(self, x, y, name=None): |
||||
return paddle.add(x, y, name) |
||||
|
||||
|
||||
class Activation(nn.Layer): |
||||
""" |
||||
The wrapper of activations. |
||||
Args: |
||||
act (str, optional): The activation name in lowercase. It must be one of ['elu', 'gelu', |
||||
'hardshrink', 'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid', |
||||
'softmax', 'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax', |
||||
'hsigmoid']. Default: None, means identical transformation. |
||||
Returns: |
||||
A callable object of Activation. |
||||
Raises: |
||||
KeyError: When parameter `act` is not in the optional range. |
||||
Examples: |
||||
from paddleseg.models.common.activation import Activation |
||||
relu = Activation("relu") |
||||
print(relu) |
||||
# <class 'paddle.nn.layer.activation.ReLU'> |
||||
sigmoid = Activation("sigmoid") |
||||
print(sigmoid) |
||||
# <class 'paddle.nn.layer.activation.Sigmoid'> |
||||
not_exit_one = Activation("not_exit_one") |
||||
# KeyError: "not_exit_one does not exist in the current dict_keys(['elu', 'gelu', 'hardshrink', |
||||
# 'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid', 'softmax', |
||||
# 'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax', 'hsigmoid'])" |
||||
""" |
||||
def __init__(self, act=None): |
||||
super(Activation, self).__init__() |
||||
self._act = act |
||||
upper_act_names = nn.layer.activation.__dict__.keys() |
||||
lower_act_names = [act.lower() for act in upper_act_names] |
||||
act_dict = dict(zip(lower_act_names, upper_act_names)) |
||||
if act is not None: |
||||
if act in act_dict.keys(): |
||||
act_name = act_dict[act] |
||||
self.act_func = eval( |
||||
"nn.layer.activation.{}()".format(act_name)) |
||||
else: |
||||
raise KeyError("{} does not exist in the current {}".format( |
||||
act, act_dict.keys())) |
||||
|
||||
def forward(self, x): |
||||
if self._act is not None: |
||||
return self.act_func(x) |
||||
else: |
||||
return x |
@ -0,0 +1,30 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle.nn as nn |
||||
|
||||
|
||||
def constant_init(param, **kwargs): |
||||
initializer = nn.initializer.Constant(**kwargs) |
||||
initializer(param, param.block) |
||||
|
||||
|
||||
def normal_init(param, **kwargs): |
||||
initializer = nn.initializer.Normal(**kwargs) |
||||
initializer(param, param.block) |
||||
|
||||
|
||||
def kaiming_normal_init(param, **kwargs): |
||||
initializer = nn.initializer.KaimingNormal(**kwargs) |
||||
initializer(param, param.block) |
@ -0,0 +1,23 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle.nn as nn |
||||
|
||||
|
||||
class Identity(nn.Layer): |
||||
def __init__(self, *args, **kwargs): |
||||
super(Identity, self).__init__() |
||||
|
||||
def forward(self, input): |
||||
return input |
@ -0,0 +1,59 @@ |
||||
import sys |
||||
|
||||
sys.path.append("E:/dataFiles/github/PaddleRS") |
||||
|
||||
import paddlers as pdrs |
||||
from paddlers import transforms as T |
||||
|
||||
# 下载和解压视盘分割数据集 |
||||
optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz' |
||||
pdrs.utils.download_and_decompress(optic_dataset, path='./') |
||||
|
||||
# 定义训练和验证时的transforms |
||||
# API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/transforms/transforms.md |
||||
train_transforms = T.Compose([ |
||||
T.Resize(target_size=512), |
||||
T.RandomHorizontalFlip(), |
||||
T.Normalize( |
||||
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), |
||||
]) |
||||
|
||||
eval_transforms = T.Compose([ |
||||
T.Resize(target_size=512), |
||||
T.Normalize( |
||||
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), |
||||
]) |
||||
|
||||
# 定义训练和验证所用的数据集 |
||||
# API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/datasets.md |
||||
train_dataset = pdrs.datasets.SegDataset( |
||||
data_dir='optic_disc_seg', |
||||
file_list='optic_disc_seg/train_list.txt', |
||||
label_list='optic_disc_seg/labels.txt', |
||||
transforms=train_transforms, |
||||
num_workers=0, |
||||
shuffle=True) |
||||
|
||||
eval_dataset = pdrs.datasets.SegDataset( |
||||
data_dir='optic_disc_seg', |
||||
file_list='optic_disc_seg/val_list.txt', |
||||
label_list='optic_disc_seg/labels.txt', |
||||
transforms=eval_transforms, |
||||
num_workers=0, |
||||
shuffle=False) |
||||
|
||||
# 初始化模型,并进行训练 |
||||
# 可使用VisualDL查看训练指标,参考https://github.com/PaddlePaddle/paddlers/blob/develop/docs/visualdl.md |
||||
num_classes = len(train_dataset.labels) |
||||
model = pdrs.tasks.FarSeg(num_classes=num_classes) |
||||
|
||||
# API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/models/semantic_segmentation.md |
||||
# 各参数介绍与调整说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/parameters.md |
||||
model.train( |
||||
num_epochs=10, |
||||
train_dataset=train_dataset, |
||||
train_batch_size=4, |
||||
eval_dataset=eval_dataset, |
||||
learning_rate=0.01, |
||||
pretrain_weights=None, |
||||
save_dir='output/farseg') |
Loading…
Reference in new issue