diff --git a/paddlers/utils/checkpoint.py b/paddlers/utils/checkpoint.py index 029c4fc..b6ada51 100644 --- a/paddlers/utils/checkpoint.py +++ b/paddlers/utils/checkpoint.py @@ -21,7 +21,18 @@ import paddle from . import logging from .download import download_and_decompress -cd_pretrain_weights_dict = {} +cd_pretrain_weights_dict = { + 'BIT': ['LEVIRCD'], + 'CDNet': ['LEVIRCD'], + 'DSAMNet': ['LEVIRCD'], + 'DSIFN': ['LEVIRCD'], + 'FCEarlyFusion': ['LEVIRCD'], + 'FCSiamConc': ['LEVIRCD'], + 'FCSiamDiff': ['LEVIRCD'], + 'FCCDN': ['LEVIRCD'], + 'SNUNet': ['LEVIRCD'], + 'STANet': ['LEVIRCD'] +} cls_pretrain_weights_dict = { 'ResNet50_vd': ['IMAGENET'], @@ -404,6 +415,29 @@ coco_weights = { 'https://paddledet.bj.bcebos.com/models/mask_rcnn_r101_vd_fpn_1x_coco.pdparams' } +levircd_weights = { + 'BIT_LEVIRCD': + 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/bit_levircd.pdparams', + 'CDNet_LEVIRCD': + 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/cdnet_levircd.pdparams', + 'DSAMNet_LEVIRCD': + 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/dsamnet_levircd.pdparams', + 'DSIFN_LEVIRCD': + 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/dsifn_levircd.pdparams', + 'FCEarlyFusion_LEVIRCD': + 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/fc_ef_levircd.pdparams', + 'FCSiamConc_LEVIRCD': + 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/fc_siam_conc_levircd.pdparams', + 'FCSiamDiff_LEVIRCD': + 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/fc_siam_diff_levircd.pdparams', + 'FCCDN_LEVIRCD': + 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/fccdn_levircd.pdparams', + 'SNUNet_LEVIRCD': + 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/snunet_levircd.pdparams', + 'STANet_LEVIRCD': + 'https://paddlers.bj.bcebos.com/pretrained/cd/levircd/weights/stanet_levircd.pdparams' +} + def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None): if flag is None: @@ -427,6 +461,8 @@ def get_pretrain_weights(flag, class_name, save_dir, backbone_name=None): url = pascalvoc_weights[weights_key] elif flag == 'COCO': url = coco_weights[weights_key] + elif flag == 'LEVIRCD': + url = levircd_weights[weights_key] else: raise ValueError('Given pretrained weights {} is undefined.'.format( flag)) diff --git a/tutorials/train/change_detection/fccdn.py b/tutorials/train/change_detection/fccdn.py index 62abbba..318fa0e 100644 --- a/tutorials/train/change_detection/fccdn.py +++ b/tutorials/train/change_detection/fccdn.py @@ -78,11 +78,11 @@ model = pdrs.tasks.cd.FCCDN() # 执行模型训练 model.train( - num_epochs=5, + num_epochs=10, train_dataset=train_dataset, train_batch_size=4, eval_dataset=eval_dataset, - save_interval_epochs=2, + save_interval_epochs=4, # 每多少次迭代记录一次日志 log_interval_steps=50, save_dir=EXP_DIR,