[Fix] Fix RS research example bugs

own
Bobholamovic 2 years ago
parent e4ddab56cd
commit 6479c51d4f
  1. 18
      examples/rs_research/custom_trainer.py
  2. 7
      examples/rs_research/predict_cd.py
  3. 9
      examples/rs_research/train_cd.py
  4. 2
      paddlers/tasks/change_detector.py

@ -23,12 +23,12 @@ from attach_tools import Attach
attach = Attach.to(paddlers.tasks.change_detector)
def make_trainer(net_type, *args, **kwargs):
def make_trainer(net_type, attach_trainer=True):
def _init_func(self,
num_classes=2,
use_mixed_loss=False,
losses=None,
**params):
**_params_):
sig = inspect.signature(net_type.__init__)
net_params = {
k: p.default
@ -36,7 +36,13 @@ def make_trainer(net_type, *args, **kwargs):
}
net_params.pop('self', None)
net_params.pop('num_classes', None)
net_params.update(params)
# Special rule to parse arguments from `_params_`.
# When using pdrs.tasks.load_model, `_params_`` is a dict with the key '_params_'.
# This bypasses the dynamic modification/creation of function signature.
if '_params_' not in _params_:
net_params.update(_params_)
else:
net_params.update(_params_['_params_'])
super(trainer_type, self).__init__(
model_name=net_type.__name__,
@ -52,7 +58,13 @@ def make_trainer(net_type, *args, **kwargs):
trainer_type = type(trainer_name, (BaseChangeDetector, ),
{'__init__': _init_func})
if attach_trainer:
trainer_type = attach(trainer_type)
return trainer_type
def make_trainer_and_build(net_type, *args, **kwargs):
trainer_type = make_trainer(net_type, attach_trainer=True)
return trainer_type(*args, **kwargs)

@ -23,8 +23,8 @@ import paddle
import paddlers
from tqdm import tqdm
import custom_model
import custom_trainer
from custom_model import CustomModel
from custom_trainer import make_trainer
def read_file_list(file_list, sep=' '):
@ -57,6 +57,9 @@ def parse_args():
if __name__ == '__main__':
args = parse_args()
# 注册训练器
make_trainer(CustomModel)
model = paddlers.tasks.load_model(args.model_dir)
if not osp.exists(args.save_dir):

@ -7,7 +7,7 @@ import paddlers as pdrs
from paddlers import transforms as T
from custom_model import CustomModel
from custom_trainer import make_trainer
from custom_trainer import make_trainer_and_build
# 数据集路径
DATA_DIR = 'data/levircd/'
@ -75,8 +75,8 @@ test_dataset = pdrs.datasets.CDDataset(
binarize_labels=True)
# 构建自定义模型CustomModel并为其自动生成训练器
# make_trainer()的首个参数为模型类型,剩余参数为模型构造所需参数
model = make_trainer(CustomModel, in_channels=3)
# make_trainer_and_build()的首个参数为模型类型,剩余参数为模型构造所需参数
model = make_trainer_and_build(CustomModel, in_channels=3)
# 构建学习率调度器
# 使用定步长学习率衰减策略
@ -108,4 +108,5 @@ model.train(
# 加载验证集上效果最好的模型
model = pdrs.tasks.load_model(osp.join(EXP_DIR, 'best_model'))
# 在测试集上计算精度指标
model.evaluate(test_dataset)
res = model.evaluate(test_dataset)
print(res)

@ -630,8 +630,6 @@ class BaseChangeDetector(BaseModel):
if isinstance(im1, str) or isinstance(im2, str):
im1 = decode_image(im1, read_raw=True)
im2 = decode_image(im2, read_raw=True)
np.save('im1_whole.npy', im1)
np.save('im2_whole.npy', im2)
ori_shape = im1.shape[:2]
# XXX: sample do not contain 'image_t1' and 'image_t2'.
sample = {'image': im1, 'image2': im2}

Loading…
Cancel
Save