Merge pull request #36 from Bobholamovic/add_cd_pretr

[Feat] Add Pretrained Models for CD Tasks
own
cc 2 years ago committed by GitHub
commit 61c74a24bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 38
      paddlers/utils/checkpoint.py
  2. 4
      tutorials/train/change_detection/fccdn.py

@ -21,7 +21,18 @@ import paddle
from . import logging
from .download import download_and_decompress
cd_pretrain_weights_dict = {}
cd_pretrain_weights_dict = {
'BIT': ['LEVIRCD'],
'CDNet': ['LEVIRCD'],
'DSAMNet': ['LEVIRCD'],
'DSIFN': ['LEVIRCD'],
'FCEarlyFusion': ['LEVIRCD'],
'FCSiamConc': ['LEVIRCD'],
'FCSiamDiff': ['LEVIRCD'],
'FCCDN': ['LEVIRCD'],
'SNUNet': ['LEVIRCD'],
'STANet': ['LEVIRCD']
}
cls_pretrain_weights_dict = {
'ResNet50_vd': ['IMAGENET'],
@ -404,6 +415,29 @@ coco_weights = {
'https://paddledet.bj.bcebos.com/models/mask_rcnn_r101_vd_fpn_1x_coco.pdparams'
}
levircd_weights = {
'BIT_LEVIRCD':
'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/bit_levircd.pdparams',
'CDNet_LEVIRCD':
'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/cdnet_levircd.pdparams',
'DSAMNet_LEVIRCD':
'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/dsamnet_levircd.pdparams',
'DSIFN_LEVIRCD':
'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/dsifn_levircd.pdparams',
'FCEarlyFusion_LEVIRCD':
'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/fc_ef_levircd.pdparams',
'FCSiamConc_LEVIRCD':
'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/fc_siam_conc_levircd.pdparams',
'FCSiamDiff_LEVIRCD':
'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/fc_siam_diff_levircd.pdparams',
'FCCDN_LEVIRCD':
'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/fccdn_levircd.pdparams',
'SNUNet_LEVIRCD':
'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/snunet_levircd.pdparams',
'STANet_LEVIRCD':
'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/stanet_levircd.pdparams'
}
def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
if flag is None:
@ -427,6 +461,8 @@ def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None):
url = pascalvoc_weights[weights_key]
elif flag == 'COCO':
url = coco_weights[weights_key]
elif flag == 'LEVIRCD':
url = levircd_weights[weights_key]
else:
raise ValueError('Given pretrained weights {} is undefined.'.format(
flag))

@ -78,11 +78,11 @@ model = pdrs.tasks.cd.FCCDN()
# 执行模型训练
model.train(
num_epochs=5,
num_epochs=10,
train_dataset=train_dataset,
train_batch_size=4,
eval_dataset=eval_dataset,
save_interval_epochs=2,
save_interval_epochs=4,
# 每多少次迭代记录一次日志
log_interval_steps=50,
save_dir=EXP_DIR,

Loading…
Cancel
Save