Update examples/rs_research

own
Bobholamovic 3 years ago
parent 0ac304641b
commit 5f353e6c51
  1. 134
      examples/rs_research/README.md
  2. 20
      examples/rs_research/attach_tools.py
  3. 253
      examples/rs_research/config_utils.py
  4. 6
      examples/rs_research/configs/levircd/bit.yaml
  5. 12
      examples/rs_research/configs/levircd/custom_model/iterative_bit_iter2_gamma01.yaml
  6. 12
      examples/rs_research/configs/levircd/custom_model/iterative_bit_iter2_gamma02.yaml
  7. 12
      examples/rs_research/configs/levircd/custom_model/iterative_bit_iter2_gamma05.yaml
  8. 12
      examples/rs_research/configs/levircd/custom_model/iterative_bit_iter3_gamma01.yaml
  9. 12
      examples/rs_research/configs/levircd/custom_model/iterative_bit_iter3_gamma02.yaml
  10. 12
      examples/rs_research/configs/levircd/custom_model/iterative_bit_iter3_gamma05.yaml
  11. 12
      examples/rs_research/configs/levircd/custom_model/iterative_bit_iter3_gamma10.yaml
  12. 74
      examples/rs_research/configs/levircd/levircd.yaml
  13. 58
      examples/rs_research/custom_model.py
  14. 29
      examples/rs_research/custom_trainer.py
  15. BIN
      examples/rs_research/params_versus_f1.png
  16. 115
      examples/rs_research/run_task.py
  17. 0
      examples/rs_research/scripts/run_benchmark.sh
  18. 0
      examples/rs_research/scripts/run_parameter_analysis.sh
  19. 0
      examples/rs_research/train.py
  20. 8
      test_tipc/configs/seg/unet/unet.yaml

@ -34,32 +34,136 @@ python ../../tools/prepare_dataset/prepare_svcd.py \
### 3.1 问题分析与思路拟定
科学研究是为了解决实际问题的,本案例也不例外。本案例的研究动机如下:随着深度学习技术应用的不断深入,变化检测领域涌现了许多。与之相对应的是,模型的参数量也越来越大。
随着深度学习技术应用的不断深入,近年来,变化检测领域涌现了许多基于全卷积神经网络(fully convolutional network, FCN)的遥感影像变化检测算法。与基于特征和基于影像块的方法相比,基于FCN的方法具有处理效率高、依赖超参数少等优势,但其缺点在于参数量往往较大,因而对训练样本的数量更为依赖。尽管中、大型变化检测数据集的数量与日俱增,训练样本日益丰富,但深度学习变化检测模型的参数量也越来越大。下图显示了从2018年到2021年一些已发表的文献中提出的基于FCN的变化检测模型的参数量与其在SVCD数据集上取得的F1分数(柱状图中bar的高度与模型参数量成正比):
[近年来变化检测模型]()
![params_versus_f1](params_versus_f1.png)
诚然,。
诚然,增大参数数量在大多数情况下等同于增加模型容量,而模型容量的增加意味着模型拟合能力的提升,从而有助于模型在实验数据集上取得更高的精度指标但是,“更大”一定意味着“更好”吗?答案显然是否定的。在实际应用中,“更大”的遥感影像变化检测模型常常遭遇如下问题:
1. 存储开销。
2. 过拟合。
1. 巨大的参数量意味着巨大的存储开销。在许多实际场景中,硬件资源往往是有限的,过多的模型参数将给部署造成困难
2. 在数据有限的情况下,大模型更易遭受过拟合,其在实验数据集上看起来良好的结果也难以泛化到真实场景
为了解决上述问题,本案例拟提出一种基于网络迭代优化思想的深度学习变化检测算法。本案例的基本思路是,构造一个轻量级的变化检测模型,并以其作为基础迭代单元。每次迭代开始时,由上一次迭代输出的概率图以及原始的输入影像对构造新的输入,实现coarse-to-fine优化。考虑到增加迭代单元的数量将使模型参数量成倍增加,在迭代过程中始终复用同一迭代单元的参数,充分挖掘变化检测网络的拟合能力,迫使其学习到更加有效的特征。这一做法类似[循环神经网络](https://baike.baidu.com/item/%E5%BE%AA%E7%8E%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C/23199490)。根据此思路可以绘制框图如下:
本案例认为,上述问题的根源在于参数量与数据量的失衡所导致的特征冗余。既然模型的特征存在冗余,是否存在某种手段,能够在固定模型参数量的前提下对特征进行优化,从而“榨取”小模型的更多潜力?基于这个观点,本案例的基本思路是设计一种基于网络迭代优化思想的深度学习变化检测算法。首先,构造一个轻量级的变化检测模型,并以其作为基础迭代单元。每次迭代开始时,由上一次迭代输出的概率图以及原始的输入影像对构造新的输入,如此逐级实现coarse-to-fine优化。考虑到增加迭代单元的数量将使模型参数量成倍增加,在迭代过程中应始终复用同一迭代单元的参数以充分挖掘变化检测网络的拟合能力,迫使其学习到更加有效的特征。这一做法类似[循环神经网络](https://baike.baidu.com/item/%E5%BE%AA%E7%8E%AF%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C/23199490)。根据此思路可以绘制框图如下:
[思路展示]()
![draft](draft.png)
### 3.2 确定baseline
### 3.2 确定baseline模型
科研工作往往需要“站在巨人的肩膀上”,在前人工作的基础上做“增量创新”。因此,对模型设计类工作而言,选用一个合适的baseline网络至关重要。考虑到本案例的出发点是解决,并且使用了。
科研工作往往需要“站在巨人的肩膀上”,在前人工作的基础上做“增量创新”。因此,对模型设计类工作而言,选用一个合适的baseline模型至关重要。考虑到本案例的出发点是解决现有模型参数量过大、冗余特征过多的问题,并且在拟定的解决方案中使用到了循环结构,用作baseline的网络结构必须足够轻量和高效(因为最直接的思路是使用baseline作为基础迭代单元)。为此,本案例选用Bitemporal Image Transformer(BIT)作为baseline。BIT是一个轻量级的深度学习变化检测模型,其基本结构如图所示:
![bit](bit.png)
BIT的核心思想在于,
### 3.3 定义新模型
[算法整体框图]()
确定了基本思路和baseline模型之后,可以绘制如下的算法整体框图:
![framework](framework.png)
依据此框图,即可在。
#### 3.3.1 自定义模型组网
在`custom_model.py`中定义模型的宏观(macro)结构以及组成模型的各个微观(micro)模块。例如,当前`custom_model.py`中定义了迭代版本的BIT模型`IterativeBIT`:
```python
@attach
class IterativeBIT(nn.Layer):
def __init__(self, num_iters=1, gamma=0.1, num_classes=2, bit_kwargs=None):
super().__init__()
if num_iters <= 0:
raise ValueError(f"`num_iters` should have positive value, but got {num_iters}.")
self.num_iters = num_iters
self.gamma = gamma
if bit_kwargs is None:
bit_kwargs = dict()
if 'num_classes' in bit_kwargs:
raise KeyError("'num_classes' should not be set in `bit_kwargs`.")
bit_kwargs['num_classes'] = num_classes
self.bit = BIT(**bit_kwargs)
def forward(self, t1, t2):
rate_map = self._init_rate_map(t1.shape)
for it in range(self.num_iters):
# Construct inputs
x1 = self._constr_iter_input(t1, rate_map)
x2 = self._constr_iter_input(t2, rate_map)
# Get logits
logits_list = self.bit(x1, x2)
# Construct rate map
prob_map = F.softmax(logits_list[0], axis=1)
rate_map = self._constr_rate_map(prob_map)
return logits_list
...
```
在编写组网相关代码时请注意以下两点:
1. 所有模型必须为`paddle.nn.Layer`的子类;
2. 包含模型整体逻辑结构的最外层模块须用`@attach`装饰;
3. 对于变化检测任务,`forward()`方法除`self`参数外还接受两个参数`t1`、`t2`,分别表示第一时相和第二时相影像。
关于模型定义的更多细节请参考[API文档]()。
#### 3.3.2 自定义训练器
在`custom_trainer.py`中定义训练器。例如,当前`custom_trainer.py`中定义了与`IterativeBIT`模型对应的训练器:
```python
@attach
class IterativeBIT(BaseChangeDetector):
def __init__(self,
num_classes=2,
use_mixed_loss=False,
losses=None,
num_iters=1,
gamma=0.1,
bit_kwargs=None,
**params):
params.update({
'num_iters': num_iters,
'gamma': gamma,
'bit_kwargs': bit_kwargs
})
super().__init__(
model_name='IterativeBIT',
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
losses=losses,
**params)
```
在编写训练器定义相关代码时请注意以下两点:
1. 对于变化检测任务,训练器必须为`paddlers.tasks.cd.BaseChangeDetector`的子类;
2. 与模型一样,训练器也须用`@attach`装饰;
3. 训练器和模型可以同名。
关于训练器定义的更多细节请参考[API文档]()。
### 3.4 进行参数分析与消融实验
#### 3.4.1 实验设置
#### 3.4.2 实验结果
#### 3.4.2 编写配置文件
#### 3.4.3 实验结果
### 3.5 \*Magic Behind
本小节涉及技术细节,对于本案例来说属于进阶内容,您可以选择性了解。
#### 3.5.1 延迟属性绑定
PaddleRS提供了,只需要。`attach_tools.Attach`对象自动。
#### 3.5.2 非侵入式轻量级配置系统
### 3.5 开展特征可视化实验
@ -75,10 +179,16 @@ python ../../tools/prepare_dataset/prepare_svcd.py \
#### 4.3.2 SVCD数据集上的对比结果
精度、FLOPs、运行时间
精度
## 5 总结与展望
### 5.1 总结
### 5.2 展望
耗时,模型大小,FLOPs
## 参考文献
> [1] Chen, Hao, and Zhenwei Shi. "A spatial-temporal attention-based method and a new dataset for remote sensing image change detection." *Remote Sensing* 12.10 (2020): 1662.

@ -0,0 +1,20 @@
class Attach(object):
def __init__(self, dst):
self.dst = dst
def __call__(self, obj, name=None):
if name is None:
# Automatically get names of functions and classes
name = obj.__name__
if hasattr(self.dst, name):
raise RuntimeError(
f"{self.dst} already has the attribute {name}, which is {getattr(self.dst, name)}."
)
setattr(self.dst, name, obj)
if hasattr(self.dst, '__all__'):
self.dst.__all__.append(name)
return obj
@staticmethod
def to(dst):
return Attach(dst)

@ -0,0 +1,253 @@
#!/usr/bin/env python
import argparse
import os.path as osp
from collections.abc import Mapping
import yaml
def _chain_maps(*maps):
chained = dict()
keys = set().union(*maps)
for key in keys:
vals = [m[key] for m in maps if key in m]
if isinstance(vals[0], Mapping):
chained[key] = _chain_maps(*vals)
else:
chained[key] = vals[0]
return chained
def read_config(config_path):
with open(config_path, 'r', encoding='utf-8') as f:
cfg = yaml.safe_load(f)
return cfg or {}
def parse_configs(cfg_path, inherit=True):
if inherit:
cfgs = []
cfgs.append(read_config(cfg_path))
while cfgs[-1].get('_base_'):
base_path = cfgs[-1].pop('_base_')
curr_dir = osp.dirname(cfg_path)
cfgs.append(
read_config(osp.normpath(osp.join(curr_dir, base_path))))
return _chain_maps(*cfgs)
else:
return read_config(cfg_path)
def _cfg2args(cfg, parser, prefix=''):
node_keys = set()
for k, v in cfg.items():
opt = prefix + k
if isinstance(v, list):
if len(v) == 0:
parser.add_argument(
'--' + opt, type=object, nargs='*', default=v)
else:
# Only apply to homogeneous lists
if isinstance(v[0], CfgNode):
node_keys.add(opt)
parser.add_argument(
'--' + opt, type=type(v[0]), nargs='*', default=v)
elif isinstance(v, dict):
# Recursively parse a dict
_, new_node_keys = _cfg2args(v, parser, opt + '.')
node_keys.update(new_node_keys)
elif isinstance(v, CfgNode):
node_keys.add(opt)
_, new_node_keys = _cfg2args(v.to_dict(), parser, opt + '.')
node_keys.update(new_node_keys)
elif isinstance(v, bool):
parser.add_argument('--' + opt, action='store_true', default=v)
else:
parser.add_argument('--' + opt, type=type(v), default=v)
return parser, node_keys
def _args2cfg(cfg, args, node_keys):
args = vars(args)
for k, v in args.items():
pos = k.find('.')
if pos != -1:
# Iteratively parse a dict
dict_ = cfg
while pos != -1:
dict_.setdefault(k[:pos], {})
dict_ = dict_[k[:pos]]
k = k[pos + 1:]
pos = k.find('.')
dict_[k] = v
else:
cfg[k] = v
for k in node_keys:
pos = k.find('.')
if pos != -1:
# Iteratively parse a dict
dict_ = cfg
while pos != -1:
dict_.setdefault(k[:pos], {})
dict_ = dict_[k[:pos]]
k = k[pos + 1:]
pos = k.find('.')
v = dict_[k]
dict_[k] = [CfgNode(v_) for v_ in v] if isinstance(
v, list) else CfgNode(v)
else:
v = cfg[k]
cfg[k] = [CfgNode(v_) for v_ in v] if isinstance(
v, list) else CfgNode(v)
return cfg
def parse_args(*args, **kwargs):
cfg_parser = argparse.ArgumentParser(add_help=False)
cfg_parser.add_argument('--config', type=str, default='')
cfg_parser.add_argument('--inherit_off', action='store_true')
cfg_args = cfg_parser.parse_known_args()[0]
cfg_path = cfg_args.config
inherit_on = not cfg_args.inherit_off
# Main parser
parser = argparse.ArgumentParser(
conflict_handler='resolve', parents=[cfg_parser])
# Global settings
parser.add_argument('cmd', choices=['train', 'eval'])
parser.add_argument('task', choices=['cd', 'clas', 'det', 'seg'])
# Data
parser.add_argument('--datasets', type=dict, default={})
parser.add_argument('--transforms', type=dict, default={})
parser.add_argument('--download_on', action='store_true')
parser.add_argument('--download_url', type=str, default='')
parser.add_argument('--download_path', type=str, default='./')
# Optimizer
parser.add_argument('--optimizer', type=dict, default={})
# Training related
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--save_interval_epochs', type=int, default=1)
parser.add_argument('--log_interval_steps', type=int, default=1)
parser.add_argument('--save_dir', default='../exp/')
parser.add_argument('--learning_rate', type=float, default=0.01)
parser.add_argument('--early_stop', action='store_true')
parser.add_argument('--early_stop_patience', type=int, default=5)
parser.add_argument('--use_vdl', action='store_true')
parser.add_argument('--resume_checkpoint', type=str)
parser.add_argument('--train', type=dict, default={})
# Loss
parser.add_argument('--losses', type=dict, nargs='+', default={})
# Model
parser.add_argument('--model', type=dict, default={})
if osp.exists(cfg_path):
cfg = parse_configs(cfg_path, inherit_on)
parser, node_keys = _cfg2args(cfg, parser, '')
node_keys = sorted(node_keys, reverse=True)
args = parser.parse_args(*args, **kwargs)
return _args2cfg(dict(), args, node_keys)
elif cfg_path != '':
raise FileNotFoundError
else:
args = parser.parse_args()
return _args2cfg(dict(), args, set())
class _CfgNodeMeta(yaml.YAMLObjectMetaclass):
def __call__(cls, obj):
if isinstance(obj, CfgNode):
return obj
return super(_CfgNodeMeta, cls).__call__(obj)
class CfgNode(yaml.YAMLObject, metaclass=_CfgNodeMeta):
yaml_tag = u'!Node'
yaml_loader = yaml.SafeLoader
# By default use a lexical scope
ctx = globals()
def __init__(self, dict_):
super().__init__()
self.type = dict_['type']
self.args = dict_.get('args', [])
self.module = dict_.get('module', '')
@classmethod
def set_context(cls, ctx):
# TODO: Implement dynamic scope with inspect.stack()
old_ctx = cls.ctx
cls.ctx = ctx
return old_ctx
def build_object(self, mod=None):
if mod is None:
mod = self._get_module(self.module)
cls = getattr(mod, self.type)
if isinstance(self.args, list):
args = build_objects(self.args)
obj = cls(*args)
elif isinstance(self.args, dict):
args = build_objects(self.args)
obj = cls(**args)
else:
raise NotImplementedError
return obj
def _get_module(self, s):
mod = None
while s:
idx = s.find('.')
if idx == -1:
next_ = s
s = ''
else:
next_ = s[:idx]
s = s[idx + 1:]
if mod is None:
mod = self.ctx[next_]
else:
mod = getattr(mod, next_)
return mod
@staticmethod
def build_objects(cfg, mod=None):
if isinstance(cfg, list):
return [CfgNode.build_objects(c, mod=mod) for c in cfg]
elif isinstance(cfg, CfgNode):
return cfg.build_object(mod=mod)
elif isinstance(cfg, dict):
return {
k: CfgNode.build_objects(
v, mod=mod)
for k, v in cfg.items()
}
else:
return cfg
def __repr__(self):
return f"(type={self.type}, args={self.args}, module={self.module or ' '})"
@classmethod
def from_yaml(cls, loader, node):
map_ = loader.construct_mapping(node)
return cls(map_)
def items(self):
yield from [('type', self.type), ('args', self.args), ('module',
self.module)]
def to_dict(self):
return dict(self.items())
def build_objects(cfg, mod=None):
return CfgNode.build_objects(cfg, mod=mod)

@ -0,0 +1,6 @@
_base_: ./levircd.yaml
save_dir: ./exp/bit/
model: !Node
type: BIT

@ -0,0 +1,12 @@
_base_: ../levircd.yaml
save_dir: ./exp/custom_model/iter2_gamma01/
model: !Node
type: IterativeBIT
args:
num_iters: 2
gamma: 0.1
num_classes: 2
bit_kwargs:
in_channels: 4

@ -0,0 +1,12 @@
_base_: ../levircd.yaml
save_dir: ./exp/custom_model/iter2_gamma02/
model: !Node
type: IterativeBIT
args:
num_iters: 2
gamma: 0.2
num_classes: 2
bit_kwargs:
in_channels: 4

@ -0,0 +1,12 @@
_base_: ../levircd.yaml
save_dir: ./exp/custom_model/iter2_gamma05/
model: !Node
type: IterativeBIT
args:
num_iters: 2
gamma: 0.5
num_classes: 2
bit_kwargs:
in_channels: 4

@ -0,0 +1,12 @@
_base_: ../levircd.yaml
save_dir: ./exp/custom_model/iter3_gamma01/
model: !Node
type: IterativeBIT
args:
num_iters: 3
gamma: 0.1
num_classes: 2
bit_kwargs:
in_channels: 4

@ -0,0 +1,12 @@
_base_: ../levircd.yaml
save_dir: ./exp/custom_model/iter3_gamma02/
model: !Node
type: IterativeBIT
args:
num_iters: 3
gamma: 0.2
num_classes: 2
bit_kwargs:
in_channels: 4

@ -0,0 +1,12 @@
_base_: ../levircd.yaml
save_dir: ./exp/custom_model/iter3_gamma05/
model: !Node
type: IterativeBIT
args:
num_iters: 3
gamma: 0.5
num_classes: 2
bit_kwargs:
in_channels: 4

@ -0,0 +1,12 @@
_base_: ../levircd.yaml
save_dir: ./exp/custom_model/iter3_gamma10/
model: !Node
type: IterativeBIT
args:
num_iters: 3
gamma: 1.0
num_classes: 2
bit_kwargs:
in_channels: 4

@ -0,0 +1,74 @@
# Basic configurations of LEVIR-CD dataset
datasets:
train: !Node
type: CDDataset
args:
data_dir: ./data/levircd/
file_list: ./data/levircd/train.txt
label_list: null
num_workers: 2
shuffle: True
with_seg_labels: False
binarize_labels: True
eval: !Node
type: CDDataset
args:
data_dir: ./data/levircd/
file_list: ./data/levircd/val.txt
label_list: null
num_workers: 0
shuffle: False
with_seg_labels: False
binarize_labels: True
transforms:
train:
- !Node
type: DecodeImg
- !Node
type: RandomFlipOrRotate
args:
probs: [0.35, 0.35]
probsf: [0.5, 0.5, 0, 0, 0]
probsr: [0.33, 0.34, 0.33]
- !Node
type: Normalize
args:
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
- !Node
type: ArrangeChangeDetector
args: ['train']
eval:
- !Node
type: DecodeImg
- !Node
type: Normalize
args:
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
- !Node
type: ArrangeChangeDetector
args: ['eval']
download_on: False
num_epochs: 40
train_batch_size: 8
optimizer: !Node
type: Adam
args:
learning_rate: !Node
type: StepDecay
module: paddle.optimizer.lr
args:
learning_rate: 0.002
step_size: 30
gamma: 0.2
save_interval_epochs: 10
log_interval_steps: 500
save_dir: ./exp/
learning_rate: 0.002
early_stop: False
early_stop_patience: 5
use_vdl: True
resume_checkpoint: ''

@ -0,0 +1,58 @@
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddlers
from paddlers.rs_models.cd import BIT
from attach_tools import Attach
attach = Attach.to(paddlers.rs_models.cd)
@attach
class IterativeBIT(nn.Layer):
def __init__(self, num_iters=1, gamma=0.1, num_classes=2, bit_kwargs=None):
super().__init__()
if num_iters <= 0:
raise ValueError(
f"`num_iters` should have positive value, but got {num_iters}.")
self.num_iters = num_iters
self.gamma = gamma
if bit_kwargs is None:
bit_kwargs = dict()
if 'num_classes' in bit_kwargs:
raise KeyError("'num_classes' should not be set in `bit_kwargs`.")
bit_kwargs['num_classes'] = num_classes
self.bit = BIT(**bit_kwargs)
def forward(self, t1, t2):
rate_map = self._init_rate_map(t1.shape)
for it in range(self.num_iters):
# Construct inputs
x1 = self._constr_iter_input(t1, rate_map)
x2 = self._constr_iter_input(t2, rate_map)
# Get logits
logits_list = self.bit(x1, x2)
# Construct rate map
prob_map = F.softmax(logits_list[0], axis=1)
rate_map = self._constr_rate_map(prob_map)
return logits_list
def _constr_iter_input(self, im, rate_map):
return paddle.concat([im.rate_map], axis=1)
def _init_rate_map(self, im_shape):
b, _, h, w = im_shape
return paddle.zeros((b, 1, h, w))
def _constr_rate_map(self, prob_map):
if prob_map.shape[1] != 2:
raise ValueError(
f"`prob_map.shape[1]` must be 2, but got {prob_map.shape[1]}.")
return (prob_map[:, 1:2] * self.gamma)

@ -0,0 +1,29 @@
import paddlers
from paddlers.tasks.change_detector import BaseChangeDetector
from attach_tools import Attach
attach = Attach.to(paddlers.tasks.change_detector)
@attach
class IterativeBIT(BaseChangeDetector):
def __init__(self,
num_classes=2,
use_mixed_loss=False,
losses=None,
num_iters=1,
gamma=0.1,
bit_kwargs=None,
**params):
params.update({
'num_iters': num_iters,
'gamma': gamma,
'bit_kwargs': bit_kwargs
})
super().__init__(
model_name='IterativeBIT',
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
losses=losses,
**params)

Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

@ -0,0 +1,115 @@
#!/usr/bin/env python
import os
import paddle
import paddlers
from paddlers import transforms as T
import custom_model
import custom_trainer
from config_utils import parse_args, build_objects, CfgNode
def format_cfg(cfg, indent=0):
s = ''
if isinstance(cfg, dict):
for i, (k, v) in enumerate(sorted(cfg.items())):
s += ' ' * indent + str(k) + ': '
if isinstance(v, (dict, list, CfgNode)):
s += '\n' + format_cfg(v, indent=indent + 1)
else:
s += str(v)
if i != len(cfg) - 1:
s += '\n'
elif isinstance(cfg, list):
for i, v in enumerate(cfg):
s += ' ' * indent + '- '
if isinstance(v, (dict, list, CfgNode)):
s += '\n' + format_cfg(v, indent=indent + 1)
else:
s += str(v)
if i != len(cfg) - 1:
s += '\n'
elif isinstance(cfg, CfgNode):
s += ' ' * indent + f"type: {cfg.type}" + '\n'
s += ' ' * indent + f"module: {cfg.module}" + '\n'
s += ' ' * indent + 'args: \n' + format_cfg(cfg.args, indent + 1)
return s
if __name__ == '__main__':
CfgNode.set_context(globals())
cfg = parse_args()
print(format_cfg(cfg))
# Automatically download data
if cfg['download_on']:
paddlers.utils.download_and_decompress(
cfg['download_url'], path=cfg['download_path'])
if cfg['cmd'] == 'train':
if not isinstance(cfg['datasets']['train'].args, dict):
raise ValueError("args of train dataset must be a dict!")
if cfg['datasets']['train'].args.get('transforms', None) is not None:
raise ValueError(
"Found key 'transforms' in args of train dataset and the value is not None."
)
train_transforms = T.Compose(
build_objects(
cfg['transforms']['train'], mod=T))
# Inplace modification
cfg['datasets']['train'].args['transforms'] = train_transforms
train_dataset = build_objects(
cfg['datasets']['train'], mod=paddlers.datasets)
if not isinstance(cfg['datasets']['eval'].args, dict):
raise ValueError("args of eval dataset must be a dict!")
if cfg['datasets']['eval'].args.get('transforms', None) is not None:
raise ValueError(
"Found key 'transforms' in args of eval dataset and the value is not None."
)
eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T))
# Inplace modification
cfg['datasets']['eval'].args['transforms'] = eval_transforms
eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets)
model = build_objects(
cfg['model'], mod=getattr(paddlers.tasks, cfg['task']))
if cfg['cmd'] == 'train':
if cfg['optimizer']:
if len(cfg['optimizer'].args) == 0:
cfg['optimizer'].args = {}
if not isinstance(cfg['optimizer'].args, dict):
raise TypeError("args of optimizer must be a dict!")
if cfg['optimizer'].args.get('parameters', None) is not None:
raise ValueError(
"Found key 'parameters' in args of optimizer and the value is not None."
)
cfg['optimizer'].args['parameters'] = model.net.parameters()
optimizer = build_objects(cfg['optimizer'], mod=paddle.optimizer)
else:
optimizer = None
model.train(
num_epochs=cfg['num_epochs'],
train_dataset=train_dataset,
train_batch_size=cfg['train_batch_size'],
eval_dataset=eval_dataset,
optimizer=optimizer,
save_interval_epochs=cfg['save_interval_epochs'],
log_interval_steps=cfg['log_interval_steps'],
save_dir=cfg['save_dir'],
learning_rate=cfg['learning_rate'],
early_stop=cfg['early_stop'],
early_stop_patience=cfg['early_stop_patience'],
use_vdl=cfg['use_vdl'],
resume_checkpoint=cfg['resume_checkpoint'] or None,
**cfg['train'])
elif cfg['cmd'] == 'eval':
state_dict = paddle.load(
os.path.join(cfg['resume_checkpoint'], 'model.pdparams'))
model.net.set_state_dict(state_dict)
res = model.evaluate(eval_dataset)
print(res)

@ -5,7 +5,7 @@ _base_: ../_base_/rsseg.yaml
save_dir: ./test_tipc/output/seg/unet/
model: !Node
type: UNet
args:
input_channel: 10
num_classes: 5
type: UNet
args:
input_channel: 10
num_classes: 5
Loading…
Cancel
Save