Add iSAID pretrained weight of FactSeg (#57)

own
Lin Manhui 2 years ago committed by GitHub
parent fad0a645b7
commit b75e0d19d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 10
      paddlers/utils/checkpoint.py
  2. 3
      tutorials/train/semantic_segmentation/factseg.py

@ -86,7 +86,8 @@ seg_pretrain_weights_dict = {
'DeepLabV3P': ['CITYSCAPES', 'PascalVOC', 'IMAGENET'],
'FastSCNN': ['CITYSCAPES'],
'HRNet': ['CITYSCAPES', 'PascalVOC'],
'BiSeNetV2': ['CITYSCAPES']
'BiSeNetV2': ['CITYSCAPES'],
'FactSeg': ['iSAID']
}
cityscapes_weights = {
@ -438,6 +439,11 @@ levircd_weights = {
'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/stanet_levircd.pdparams'
}
isaid_weights = {
'FactSeg_iSAID':
'https://paddlers.bj.bcebos.com/pretrained/seg/isaid/weights/factseg_isaid.pdparams'
}
def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
if flag is None:
@ -463,6 +469,8 @@ def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
url = coco_weights[weights_key]
elif flag == 'LEVIRCD':
url = levircd_weights[weights_key]
elif flag == 'iSAID':
url = isaid_weights[weights_key]
else:
raise ValueError('Given pretrained weights {} is undefined.'.format(
flag))

@ -83,7 +83,8 @@ model.train(
# 每多少次迭代记录一次日志
log_interval_steps=4,
save_dir=EXP_DIR,
pretrain_weights=None,
# 使用iSAID数据集上的预训练权重
pretrain_weights='iSAID',
# 初始学习率大小
learning_rate=0.001,
# 是否使用early stopping策略,当精度不再改善时提前终止训练

Loading…
Cancel
Save