|
|
@ -1,6 +1,6 @@ |
|
|
|
# PaddleRS训练API说明 |
|
|
|
# PaddleRS训练API说明 |
|
|
|
|
|
|
|
|
|
|
|
**训练器**封装了模型训练、验证、量化以及动态图推理等逻辑,定义在`paddlers/tasks/`目录下的文件中。为了方便用户使用,PaddleRS为所有支持的模型均提供了继承自父类[`BaseModel`](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/base.py)的训练器,并对外提供数个API。变化检测、场景分类、图像分割以及目标检测任务对应的训练器类型分别为`BaseChangeDetector`、`BaseClassifier`、`BaseDetector`和`BaseSegmenter`。本文档介绍训练器的初始化函数以及`train()`、`evaluate()` API。 |
|
|
|
**训练器**封装了模型训练、验证、量化以及动态图推理等逻辑,定义在`paddlers/tasks/`目录下的文件中。为了方便用户使用,PaddleRS为所有支持的模型均提供了继承自父类[`BaseModel`](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/base.py)的训练器,并对外提供数个API。变化检测、场景分类、目标检测、图像复原以及图像分割任务对应的训练器类型分别为`BaseChangeDetector`、`BaseClassifier`、`BaseDetector`、`BaseRestorer`和`BaseSegmenter`。本文档介绍训练器的初始化函数以及`train()`、`evaluate()` API。 |
|
|
|
|
|
|
|
|
|
|
|
## 初始化训练器 |
|
|
|
## 初始化训练器 |
|
|
|
|
|
|
|
|
|
|
@ -28,7 +28,9 @@ |
|
|
|
|
|
|
|
|
|
|
|
### 初始化`BaseRestorer`子类对象 |
|
|
|
### 初始化`BaseRestorer`子类对象 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
- 一般支持设置`sr_factor`参数,表示超分辨率倍数;对于不支持超分辨率重建任务的模型,`sr_factor`设置为`None`。 |
|
|
|
|
|
|
|
- 可通过`losses`参数指定模型训练时使用的损失函数,传入实参需为可调用对象或字典。手动指定的`losses`与子类的`default_loss()`方法返回值必须具有相同的格式。 |
|
|
|
|
|
|
|
- 不同的子类支持与模型相关的输入参数,详情请参考[模型定义](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/rs_models/res)和[训练器定义](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py)。 |
|
|
|
|
|
|
|
|
|
|
|
### 初始化`BaseSegmenter`子类对象 |
|
|
|
### 初始化`BaseSegmenter`子类对象 |
|
|
|
|
|
|
|
|
|
|
@ -180,6 +182,46 @@ def train(self, |
|
|
|
|
|
|
|
|
|
|
|
### `BaseRestorer.train()` |
|
|
|
### `BaseRestorer.train()` |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
接口形式: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
```python |
|
|
|
|
|
|
|
def train(self, |
|
|
|
|
|
|
|
num_epochs, |
|
|
|
|
|
|
|
train_dataset, |
|
|
|
|
|
|
|
train_batch_size=2, |
|
|
|
|
|
|
|
eval_dataset=None, |
|
|
|
|
|
|
|
optimizer=None, |
|
|
|
|
|
|
|
save_interval_epochs=1, |
|
|
|
|
|
|
|
log_interval_steps=2, |
|
|
|
|
|
|
|
save_dir='output', |
|
|
|
|
|
|
|
pretrain_weights='CITYSCAPES', |
|
|
|
|
|
|
|
learning_rate=0.01, |
|
|
|
|
|
|
|
lr_decay_power=0.9, |
|
|
|
|
|
|
|
early_stop=False, |
|
|
|
|
|
|
|
early_stop_patience=5, |
|
|
|
|
|
|
|
use_vdl=True, |
|
|
|
|
|
|
|
resume_checkpoint=None): |
|
|
|
|
|
|
|
``` |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
其中各参数的含义如下: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|参数名称|类型|参数说明|默认值| |
|
|
|
|
|
|
|
|-------|----|--------|-----| |
|
|
|
|
|
|
|
|`num_epochs`|`int`|训练的epoch数目。|| |
|
|
|
|
|
|
|
|`train_dataset`|`paddlers.datasets.ResDataset`|训练数据集。|| |
|
|
|
|
|
|
|
|`train_batch_size`|`int`|训练时使用的batch size。|`2`| |
|
|
|
|
|
|
|
|`eval_dataset`|`paddlers.datasets.ResDataset` \| `None`|验证数据集。|`None`| |
|
|
|
|
|
|
|
|`optimizer`|`paddle.optimizer.Optimizer` \| `None`|训练时使用的优化器。若为`None`,则使用默认定义的优化器。|`None`| |
|
|
|
|
|
|
|
|`save_interval_epochs`|`int`|训练时存储模型的间隔epoch数。|`1`| |
|
|
|
|
|
|
|
|`log_interval_steps`|`int`|训练时打印日志的间隔step数(即迭代数)。|`2`| |
|
|
|
|
|
|
|
|`save_dir`|`str`|存储模型的路径。|`'output'`| |
|
|
|
|
|
|
|
|`pretrain_weights`|`str` \| `None`|预训练权重的名称/路径。若为`None`,则不适用预训练权重。|`'CITYSCAPES'`| |
|
|
|
|
|
|
|
|`learning_rate`|`float`|训练时使用的学习率大小,适用于默认优化器。|`0.01`| |
|
|
|
|
|
|
|
|`lr_decay_power`|`float`|学习率衰减系数,适用于默认优化器。|`0.9`| |
|
|
|
|
|
|
|
|`early_stop`|`bool`|训练过程是否启用早停策略。|`False`| |
|
|
|
|
|
|
|
|`early_stop_patience`|`int`|启用早停策略时的`patience`参数(参见[`EarlyStop`](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/utils/utils.py))。|`5`| |
|
|
|
|
|
|
|
|`use_vdl`|`bool`|是否启用VisualDL日志。|`True`| |
|
|
|
|
|
|
|
|`resume_checkpoint`|`str` \| `None`|检查点路径。PaddleRS支持从检查点(包含先前训练过程中存储的模型权重和优化器权重)继续训练,但需注意`resume_checkpoint`与`pretrain_weights`不得同时设置为`None`以外的值。|`None`| |
|
|
|
|
|
|
|
|
|
|
|
### `BaseSegmenter.train()` |
|
|
|
### `BaseSegmenter.train()` |
|
|
|
|
|
|
|
|
|
|
@ -284,7 +326,7 @@ def evaluate(self, eval_dataset, batch_size=1, return_details=False): |
|
|
|
|
|
|
|
|
|
|
|
``` |
|
|
|
``` |
|
|
|
{"top1": top1准确率, |
|
|
|
{"top1": top1准确率, |
|
|
|
"top5": `top5准确率} |
|
|
|
"top5": top5准确率} |
|
|
|
``` |
|
|
|
``` |
|
|
|
|
|
|
|
|
|
|
|
### `BaseDetector.evaluate()` |
|
|
|
### `BaseDetector.evaluate()` |
|
|
@ -324,6 +366,26 @@ def evaluate(self, |
|
|
|
|
|
|
|
|
|
|
|
### `BaseRestorer.evaluate()` |
|
|
|
### `BaseRestorer.evaluate()` |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
接口形式: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
```python |
|
|
|
|
|
|
|
def evaluate(self, eval_dataset, batch_size=1, return_details=False): |
|
|
|
|
|
|
|
``` |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
输入参数如下: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|参数名称|类型|参数说明|默认值| |
|
|
|
|
|
|
|
|-------|----|--------|-----| |
|
|
|
|
|
|
|
|`eval_dataset`|`paddlers.datasets.ResDataset`|评估数据集。|| |
|
|
|
|
|
|
|
|`batch_size`|`int`|评估时使用的batch size(多卡训练时,为所有设备合计batch size)。|`1`| |
|
|
|
|
|
|
|
|`return_details`|`bool`|*当前版本请勿手动设置此参数。*|`False`| |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
输出为一个`collections.OrderedDict`对象,包含如下键值对: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
``` |
|
|
|
|
|
|
|
{"psnr": PSNR指标, |
|
|
|
|
|
|
|
"ssim": SSIM指标} |
|
|
|
|
|
|
|
``` |
|
|
|
|
|
|
|
|
|
|
|
### `BaseSegmenter.evaluate()` |
|
|
|
### `BaseSegmenter.evaluate()` |
|
|
|
|
|
|
|
|
|
|
|