diff --git a/paddlers/utils/checkpoint.py b/paddlers/utils/checkpoint.py index ca59645..859a29a 100644 --- a/paddlers/utils/checkpoint.py +++ b/paddlers/utils/checkpoint.py @@ -86,7 +86,8 @@ seg_pretrain_weights_dict = { 'DeepLabV3P': ['CITYSCAPES', 'PascalVOC', 'IMAGENET'], 'FastSCNN': ['CITYSCAPES'], 'HRNet': ['CITYSCAPES', 'PascalVOC'], - 'BiSeNetV2': ['CITYSCAPES'] + 'BiSeNetV2': ['CITYSCAPES'], + 'FactSeg': ['iSAID'] } cityscapes_weights = { @@ -438,6 +439,11 @@ levircd_weights = { 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/stanet_levircd.pdparams' } +isaid_weights = { + 'FactSeg_iSAID': + 'https://paddlers.bj.bcebos.com/pretrained/seg/isaid/weights/factseg_isaid.pdparams' +} + def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None): if flag is None: @@ -463,6 +469,8 @@ def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None): url = coco_weights[weights_key] elif flag == 'LEVIRCD': url = levircd_weights[weights_key] + elif flag == 'iSAID': + url = isaid_weights[weights_key] else: raise ValueError('Given pretrained weights {} is undefined.'.format( flag)) diff --git a/tutorials/train/semantic_segmentation/factseg.py b/tutorials/train/semantic_segmentation/factseg.py index fb26169..1da6c16 100644 --- a/tutorials/train/semantic_segmentation/factseg.py +++ b/tutorials/train/semantic_segmentation/factseg.py @@ -83,7 +83,8 @@ model.train( # 每多少次迭代记录一次日志 log_interval_steps=4, save_dir=EXP_DIR, - pretrain_weights=None, + # 使用iSAID数据集上的预训练权重 + pretrain_weights='iSAID', # 初始学习率大小 learning_rate=0.001, # 是否使用early stopping策略,当精度不再改善时提前终止训练