Add read_geo_info and update docs

own
Bobholamovic 2 years ago
parent dcf40fa52f
commit a08e504401
  1. 21
      deploy/README.md
  2. 100
      docs/apis/data.md
  3. 184
      docs/apis/infer.md
  4. 58
      docs/apis/train.md
  5. 52
      docs/intro/transforms.md
  6. 4
      paddlers/tasks/classifier.py
  7. 10
      paddlers/tasks/object_detector.py
  8. 19
      paddlers/transforms/__init__.py
  9. 75
      paddlers/transforms/operators.py

@ -21,9 +21,9 @@ from paddlers.deploy import Predictor
# model_dir: 模型路径(必须是导出的部署或量化模型)。 # model_dir: 模型路径(必须是导出的部署或量化模型)。
# use_gpu: 是否使用GPU,默认为False。 # use_gpu: 是否使用GPU,默认为False。
# gpu_id: 使用GPU的ID,默认为0。 # gpu_id: 使用GPU的ID,默认为0。
# cpu_thread_num:使用cpu进行预测时的线程数,默认为1。 # cpu_thread_num:使用CPU进行预测时的线程数,默认为1。
# use_mkl: 是否使用mkldnn计算库,CPU情况下使用,默认为False。 # use_mkl: 是否使用MKL-DNN计算库,CPU情况下使用,默认为False。
# mkl_thread_num: mkldnn计算线程数,默认为4。 # mkl_thread_num: MKL-DNN计算线程数,默认为4。
# use_trt: 是否使用TensorRT,默认为False。 # use_trt: 是否使用TensorRT,默认为False。
# use_glog: 是否启用glog日志, 默认为False。 # use_glog: 是否启用glog日志, 默认为False。
# memory_optimize: 是否启动内存优化,默认为True。 # memory_optimize: 是否启动内存优化,默认为True。
@ -34,21 +34,20 @@ from paddlers.deploy import Predictor
predictor = Predictor("static_models/", use_gpu=True) predictor = Predictor("static_models/", use_gpu=True)
# 第二步:调用Predictor的predict()方法执行推理。该方法接受的输入参数如下: # 第二步:调用Predictor的predict()方法执行推理。该方法接受的输入参数如下:
# img_file(List[str or tuple or np.ndarray], str, tuple, or np.ndarray): # img_file: 对于场景分类、图像复原、目标检测和图像分割任务来说,该参数可为单一图像路径,或是解码后的、排列格式为(H, W, C)
# 对于场景分类、图像复原、目标检测和语义分割任务来说,该参数可为单一图像路径,或是解码后的、排列格式为(H, W, C) # 且具有float32类型的图像数据(表示为numpy的ndarray形式),或者是一组图像路径或np.ndarray对象构成的列表;对于变化检测
# 且具有float32类型的BGR图像(表示为numpy的ndarray形式),或者是一组图像路径或np.ndarray对象构成的列表;对于变化检测
# 任务来说,该参数可以为图像路径二元组(分别表示前后两个时相影像路径),或是两幅图像组成的二元组,或者是上述两种二元组 # 任务来说,该参数可以为图像路径二元组(分别表示前后两个时相影像路径),或是两幅图像组成的二元组,或者是上述两种二元组
# 之一构成的列表。 # 之一构成的列表。
# topk(int): 场景分类模型预测时使用,表示预测前topk的结果。默认值为1。 # topk: 场景分类模型预测时使用,表示选取模型输出概率大小排名前`topk`的类别作为最终结果。默认值为1。
# transforms (paddlers.transforms): 数据预处理操作。默认值为None, 即使用`model.yml`中保存的数据预处理操作 # transforms: 对输入数据应用的数据变换算子。若为None,则使用从`model.yml`中读取的算子。默认值为None
# warmup_iters (int): 预热轮数,用于评估模型推理以及前后处理速度。若大于1,会预先重复预测warmup_iters,而后才开始正式的预测及其速度评估。默认为0。 # warmup_iters: 预热轮数,用于评估模型推理以及前后处理速度。若大于1,会预先重复执行`warmup_iters`次推理,而后才开始正式的预测及其速度评估。默认为0。
# repeats (int): 重复次数,用于评估模型推理以及前后处理速度。若大于1,会预测repeats次取时间平均值。默认值为1。 # repeats: 重复次数,用于评估模型推理以及前后处理速度。若大于1,会执行`repeats`次预测并取时间平均值。默认值为1。
# #
# 下面的语句传入两幅输入影像的路径 # 下面的语句传入两幅输入影像的路径
res = predictor.predict(("demo_data/A.png", "demo_data/B.png")) res = predictor.predict(("demo_data/A.png", "demo_data/B.png"))
# 第三步:解析predict()方法返回的结果。 # 第三步:解析predict()方法返回的结果。
# 对于语义分割和变化检测任务而言,predict()方法返回的结果为一个字典或字典构成的列表。字典中的`label_map`键对应的值为类别标签图,对于二值变化检测 # 对于图像分割和变化检测任务而言,predict()方法返回的结果为一个字典或字典构成的列表。字典中的`label_map`键对应的值为类别标签图,对于二值变化检测
# 任务而言只有0(不变类)或者1(变化类)两种取值;`score_map`键对应的值为类别概率图,对于二值变化检测任务来说一般包含两个通道,第0个通道表示不发生 # 任务而言只有0(不变类)或者1(变化类)两种取值;`score_map`键对应的值为类别概率图,对于二值变化检测任务来说一般包含两个通道,第0个通道表示不发生
# 变化的概率,第1个通道表示发生变化的概率。如果返回的结果是由字典构成的列表,则列表中的第n项与输入的img_file中的第n项对应。 # 变化的概率,第1个通道表示发生变化的概率。如果返回的结果是由字典构成的列表,则列表中的第n项与输入的img_file中的第n项对应。
# #

@ -2,10 +2,102 @@
## 数据集 ## 数据集
在PaddleRS中,所有数据集均继承自 在PaddleRS中,所有数据集均继承自父类[`BaseDataset`](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/datasets/base.py)。
## 数据预处理/数据增强算子 ### `CDDataset`
## 组合数据处理算子 https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/datasets/cd_dataset.py
## `decode_image()` ### `ClasDataset`
https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/datasets/clas_dataset.py
### `COCODetDataset`
https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/datasets/coco.py
### `VOCDetDataset`
https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/datasets/voc.py
### `SegDataset`
https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/datasets/seg_dataset.py
## 数据读取API
遥感影像的来源多样,数据格式十分繁杂。PaddleRS为不同类型、不同格式的遥感影像提供了统一的读取接口。目前,PaddleRS支持.png、.jpg、.bmp、.npy等常见文件格式的读取,也支持处理遥感领域常用的GeoTiff、img等影像格式。
根据实际需要,用户可以选择`paddlers.transforms.decode_image()`或`paddlers.transforms.DecodeImg`进行数据读取。`DecodeImg`是[数据变换算子](#数据变换算子)之一,可以与其它算子组合使用。`decode_image`是对`DecodeImg`算子的封装,方便用户以函数调用的方式使用。
`decode_image()`函数的参数列表如下:
|参数名称|类型|参数说明|默认值|
|-------|----|--------|-----|
|`im_path`|`str`|输入图像路径。||
|`to_rgb`|`bool`|若为`True`,则执行BGR到RGB格式的转换。|`True`|
|`to_uint8`|`bool`|若为`True`,则将读取的图像数据量化并转换为uint8类型。|`True`|
|`decode_bgr`|`bool`|若为`True`,则自动将非地学格式影像(如jpeg影像)解析为BGR格式。|`True`|
|`decode_sar`|`bool`|若为`True`,则自动将2通道的地学格式影像(如GeoTiff影像)作为SAR影像解析。|`True`|
|`read_geo_info`|`bool`|若为`True`,则从影像中读取地理信息。|`False`|
返回格式如下:
- 若`read_geo_info`为`False`,则以np.ndarray形式返回读取的影像数据([h, w, c]排布);
- 若`read_geo_info`为`True`,则返回一个二元组,其中第一个元素为读取的影像数据,第二个元素为一个字典,其中的键值对为影像的地理信息,如地理变换信息、地理投影信息等。
## 数据变换算子
在PaddleRS中定义了一系列类,这些类在实例化之后,可通过调用`__call__`方法执行某种特定的数据预处理或数据增强操作。PaddleRS将这些类称为数据预处理/数据增强算子,并统称为**数据变换算子**。所有数据变换算子均继承自父类[`Transform`](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/transforms/operators.py)。
### `Transform`
`Transform`对象的`__call__`方法接受唯一的参数`sample`。`sample`必须为字典或字典构成的序列。当`sample`是序列时,为`sample`中的每个字典执行数据变换操作,并将变换结果依次存储在一个Python built-in list中返回;当`sample`是字典时,`Transform`对象根据其中的一些键值对提取输入(这些键称为“输入键”),执行变换后,将结果以键值对的形式写入`sample`中(这些键称为“输出键”)。需要注意的是,目前PaddleRS中许多`Transform`对象都存在复写行为,即,输入键与输出键之间存在交集。`sample`中常见的键名及其表示的含义如下表:
|键名|说明|
|----|----|
|`'image'`|影像路径或数据。对于变化检测任务,指第一时相影像数据。|
|`'image2'`|变化检测任务中第二时相影像数据。|
|`'image_t1'`|变化检测任务中第一时相影像路径。|
|`'image_t2'`|变化检测任务中第二时相影像路径。|
|`'mask'`|图像分割/变化检测任务中的真值标签路径或数据。|
|`'aux_masks'`|图像分割/变化检测任务中的辅助标签路径或数据。|
|`'gt_bbox'`|目标检测任务中的检测框标注数据。|
|`'gt_poly'`|目标检测任务中的多边形标注数据。|
## 组合数据变换算子
使用`paddlers.transforms.Compose`对一组数据变换算子进行组合。`Compose`对象在构造时接受一个列表输入。在调用`Compose`对象时,相当于串行执行列表中的每一个数据变换算子。示例如下:
```python
# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
train_transforms = T.Compose([
# 读取影像
T.DecodeImg(),
# 将影像缩放到512x512大小
T.Resize(target_size=512),
# 以50%的概率实施随机水平翻转
T.RandomHorizontalFlip(prob=0.5),
# 将数据归一化到[-1,1]
T.Normalize(
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
# 挑选并排布后续需要使用的信息
T.ArrangeSegmenter('train')
])
```
一般来说,`Compose`对象接受的数据变换算子列表中,首个元素为`paddlers.transforms.DecodeImg`对象,用于读取影像数据;最后一个元素为[`Arrange`算子](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/transforms/operators.py),用于从`sample`字典中抽取信息并排列。
对于图像分割任务和变化检测任务的验证集而言,可在`Arrange`算子之前插入`ReloadMask`算子以重新加载真值标签。示例如下:
```python
eval_transforms = T.Compose([
T.DecodeImg(),
T.Resize(target_size=512),
T.Normalize(
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
# 重新加载标签
T.ReloadMask(),
T.ArrangeSegmenter('eval')
])
```

@ -1,11 +1,195 @@
# PaddleRS推理API说明 # PaddleRS推理API说明
PaddleRS的动态图推理和静态图推理能力分别由训练器([`BaseModel`](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/base.py)及其子类)和**预测器**(`paddlers.deploy.Predictor`)提供。
## 动态图推理API ## 动态图推理API
### 整图推理 ### 整图推理
#### `BaseChangeDetector.predict()`
接口形式:
```python
def predict(self, img_file, transforms=None):
```
输入参数:
|参数名称|类型|参数说明|默认值|
|-------|----|--------|-----|
|`img_file`|`list[tuple]` \| `tuple[str\|np.ndarray]`|输入影像对数据(NumPy数组形式)或输入影像对路径。若仅预测一个影像对,使用一个元组顺序包含第一时相影像数据/路径以及第二时相影像数据/路径。若需要一次性预测一组影像对,以列表包含这些影像对的数据或路径(每个影像对对应列表中的一个元组)。||
|`transforms`|`paddlers.transforms.Compose` \| `None`|对输入数据应用的数据变换算子。若为`None`,则使用训练器在验证阶段使用的数据变换算子。|`None`|
返回格式:
若`img_file`是一个元组,则返回对象为包含下列键值对的字典:
```
{"label map": 输出类别标签(以[h, w]格式排布),"score_map": 模型输出的各类别概率(以[h, w, c]格式排布)}
```
若`img_file`是一个列表,则返回对象为与`img_file`等长的列表,其中的每一项为一个字典(键值对如上所示),顺序对应`img_file`中的每个元素。
#### `BaseClassifier.predict()`
接口形式:
```python
def predict(self, img_file, transforms=None):
```
输入参数:
|参数名称|类型|参数说明|默认值|
|-------|----|--------|-----|
|`img_file`|`list[str\|np.ndarray]` \| `str` \| `np.ndarray`|输入影像数据(NumPy数组形式)或输入影像路径。若需要一次性预测一组影像,以列表包含这些影像的数据或路径(每幅影像对应列表中的一个元素)。||
|`transforms`|`paddlers.transforms.Compose` \| `None`|对输入数据应用的数据变换算子。若为`None`,则使用训练器在验证阶段使用的数据变换算子。|`None`|
返回格式:
若`img_file`是一个字符串或NumPy数组,则返回对象为包含下列键值对的字典:
```
{"label map": 输出类别标签,
"scores_map": 输出类别概率,
"label_names_map": 输出类别名称}
```
若`img_file`是一个列表,则返回对象为与`img_file`等长的列表,其中的每一项为一个字典(键值对如上所示),顺序对应`img_file`中的每个元素。
#### `BaseDetector.predict()`
接口形式:
```python
def predict(self, img_file, transforms=None):
```
输入参数:
|参数名称|类型|参数说明|默认值|
|-------|----|--------|-----|
|`img_file`|`list[str\|np.ndarray]` \| `str` \| `np.ndarray`|输入影像数据(NumPy数组形式)或输入影像路径。若需要一次性预测一组影像,以列表包含这些影像的数据或路径(每幅影像对应列表中的一个元素)。||
|`transforms`|`paddlers.transforms.Compose` \| `None`|对输入数据应用的数据变换算子。若为`None`,则使用训练器在验证阶段使用的数据变换算子。|`None`|
返回格式:
若`img_file`是一个字符串或NumPy数组,则返回对象为一个列表,列表中每个元素对应一个预测的目标框。列表中的元素为包含下列键值对的字典:
```
{"category_id": 类别ID,
"category": 类别名称,
"bbox": 目标框位置信息,依次包含目标框左上角的横、纵坐标以及目标框的宽度和长度,
"score": 类别置信度,
"mask": [RLE格式](https://baike.baidu.com/item/rle/366352)的掩模图(mask),仅实例分割模型预测结果包含此键值对}
```
若`img_file`是一个列表,则返回对象为与`img_file`等长的列表,其中的每一项为一个由字典(键值对如上所示)构成的列表,顺序对应`img_file`中的每个元素。
#### `BaseSegmenter.predict()`
接口形式:
```python
def predict(self, img_file, transforms=None):
```
输入参数:
|参数名称|类型|参数说明|默认值|
|-------|----|--------|-----|
|`img_file`|`list[str\|np.ndarray]` \| `str` \| `np.ndarray`|输入影像数据(NumPy数组形式)或输入影像路径。若需要一次性预测一组影像,以列表包含这些影像的数据或路径(每幅影像对应列表中的一个元素)。||
|`transforms`|`paddlers.transforms.Compose` \| `None`|对输入数据应用的数据变换算子。若为`None`,则使用训练器在验证阶段使用的数据变换算子。|`None`|
返回格式:
若`img_file`是一个字符串或NumPy数组,则返回对象为包含下列键值对的字典:
```
{"label map": 输出类别标签(以[h, w]格式排布),"score_map": 模型输出的各类别概率(以[h, w, c]格式排布)}
```
若`img_file`是一个列表,则返回对象为与`img_file`等长的列表,其中的每一项为一个字典(键值对如上所示),顺序对应`img_file`中的每个元素。
### 滑窗推理 ### 滑窗推理
考虑到遥感影像的大幅面性质,PaddleRS为部分任务提供了滑窗推理支持。PaddleRS的滑窗推理功能具有如下特色:
1. 为了解决一次读入整张大图直接导致内存不足的问题,PaddleRS采用延迟载入内存的技术,一次仅读取并处理一个窗口内的影像块。
2. 用户可自定义滑窗的大小和步长。支持滑窗重叠,对于窗口之间重叠的部分,PaddleRS将自动对模型预测结果进行融合。
3. 支持将推理结果保存为GeoTiff格式,支持对地理变换信息、地理投影信息的读取与写入。
目前,图像分割训练器([`BaseSegmenter`](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/segmenter.py)及其子类)与变化检测训练器([`BaseChangeDetector`](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/change_detector.py)及其子类)具有动态图滑窗推理API,以图像分割任务的API为例,说明如下:
接口形式:
```python
def slider_predict(self,
img_file,
save_dir,
block_size,
overlap=36,
transforms=None):
```
输入参数列表:
|参数名称|类型|参数说明|默认值|
|-------|----|--------|-----|
|`img_file`|`str`|输入影像路径。||
|`save_dir`|`str`|预测结果输出路径。||
|`block_size`|`list[int]` \| `tuple[int]` \| `int`|滑窗的窗口大小(以列表或元组指定长、宽或以一个整数指定相同的长宽)。||
|`overlap`|`list[int]` \| `tuple[int]` \| `int`|滑窗的滑动步长(以列表或元组指定长、宽或以一个整数指定相同的长宽)。|`36`|
|`transforms`|`paddlers.transforms.Compose` \| `None`|对输入数据应用的数据变换算子。若为`None`,则使用训练器在验证阶段使用的数据变换算子。|`None`|
变化检测任务的滑窗推理API与图像分割任务类似,但需要注意的是输出结果中存储的地理变换、投影等信息以从第一时相影像中读取的信息为准。
## 静态图推理API ## 静态图推理API
### Python API ### Python API
[将模型导出为部署格式](https://github.com/PaddlePaddle/PaddleRS/blob/develop/deploy/export/README.md)或执行模型量化后,PaddleRS提供`paddlers.deploy.Predictor`用于加载部署或量化格式模型以及执行基于[Paddle Inference](https://www.paddlepaddle.org.cn/tutorials/projectdetail/3952715)的推理。
#### 初始化`Predictor`对象
`Predictor.__init__()`接受如下参数:
|参数名称|类型|参数说明|默认值|
|-------|----|--------|-----|
|`model_dir`|`str`|模型路径(必须是导出的部署或量化模型)。||
|`use_gpu`|`bool`|是否使用GPU。|`False`|
|`gpu_id`|`int`|使用GPU的ID。|`0`|
|`cpu_thread_num`|`int`|使用CPU执行推理时的线程数。|`1`|
|`use_mkl`|`bool`|是否使用MKL-DNN计算库(此选项仅在使用CPU执行推理时生效)。|`False`|
|`mkl_thread_num`|`int`|MKL-DNN计算线程数。|`4`|
|`use_trt`|`bool`|是否使用TensorRT。|`False`|
|`use_glog`|`bool`|是否启用glog日志。|`False`|
|`memory_optimize`|`bool`|是否启用内存优化。|`True`|
|`max_trt_batch_size`|`int`|在使用TensorRT时配置的最大batch size。|`1`|
|`trt_precision_mode`|`str`|在使用TensorRT时采用的精度,可选值为`'float32'`或`'float16'`。|`'float32'`|
#### `Predictor.predict()`
接口形式:
```python
def predict(self,
img_file,
topk=1,
transforms=None,
warmup_iters=0,
repeats=1):
```
输入参数列表:
|参数名称|类型|参数说明|默认值|
|-------|----|--------|-----|
|`img_file`|`list[str\|tuple\|np.ndarray]` \| `str` \| `tuple` \| `np.ndarray`|对于场景分类、目标检测和图像分割任务来说,该参数可为单一图像路径,或是解码后的、排列格式为[h, w, c]且具有float32类型的图像数据(表示为NumPy数组形式),或者是一组图像路径或np.ndarray对象构成的列表;对于变化检测任务来说,该参数可以为图像路径二元组(分别表示前后两个时相影像路径),或是解码后的两幅图像组成的二元组,或者是上述两种二元组之一构成的列表。||
|`topk`|`int`|场景分类模型预测时使用,表示选取模型输出概率大小排名前`topk`的类别作为最终结果。|`1`|
|`transforms`|`paddlers.transforms.Compose`\|`None`|对输入数据应用的数据变换算子。若为`None`,则使用从`model.yml`中读取的算子。|`None`|
|`warmup_iters`|`int`|预热轮数,用于评估模型推理以及前后处理速度。若大于1,将预先重复执行`warmup_iters`次推理,而后才开始正式的预测及其速度评估。|`0`|
|`repeats`|`int`|重复次数,用于评估模型推理以及前后处理速度。若大于1,将执行`repeats`次预测并取时间平均值。|`1`|
`Predictor.predict()`的返回格式与相应的动态图推理API的返回格式完全相同,详情请参考[动态图推理API](#动态图推理api)。

@ -33,6 +33,8 @@
### `BaseChangeDetector.train()` ### `BaseChangeDetector.train()`
接口形式:
```python ```python
def train(self, def train(self,
num_epochs, num_epochs,
@ -74,6 +76,8 @@ def train(self,
### `BaseClassifier.train()` ### `BaseClassifier.train()`
接口形式:
```python ```python
def train(self, def train(self,
num_epochs, num_epochs,
@ -115,6 +119,8 @@ def train(self,
### `BaseDetector.train()` ### `BaseDetector.train()`
接口形式:
```python ```python
def train(self, def train(self,
num_epochs, num_epochs,
@ -166,6 +172,8 @@ def train(self,
### `BaseSegmenter.train()` ### `BaseSegmenter.train()`
接口形式:
```python ```python
def train(self, def train(self,
num_epochs, num_epochs,
@ -209,6 +217,8 @@ def train(self,
### `BaseChangeDetector.evaluate()` ### `BaseChangeDetector.evaluate()`
接口形式:
```python ```python
def evaluate(self, eval_dataset, batch_size=1, return_details=False): def evaluate(self, eval_dataset, batch_size=1, return_details=False):
``` ```
@ -224,27 +234,29 @@ def evaluate(self, eval_dataset, batch_size=1, return_details=False):
当`return_details`为`False`(默认行为)时,输出为一个`collections.OrderedDict`对象。对于二类变化检测任务,输出包含如下键值对: 当`return_details`为`False`(默认行为)时,输出为一个`collections.OrderedDict`对象。对于二类变化检测任务,输出包含如下键值对:
``` ```
{"iou": 变化类的IoU指标 {"iou": 变化类的IoU指标,
"f1": 变化类的F1分数, "f1": 变化类的F1分数,
"oacc": 总体精度(准确率), "oacc": 总体精度(准确率),
"kappa": kappa系数} "kappa": kappa系数}
``` ```
对于多类变化检测任务,输出包含如下键值对: 对于多类变化检测任务,输出包含如下键值对:
``` ```
{"miou": mIoU指标 {"miou": mIoU指标,
"category_iou": 各类的IoU指标, "category_iou": 各类的IoU指标,
"oacc": 总体精度(准确率), "oacc": 总体精度(准确率),
"category_acc": 各类精确率, "category_acc": 各类精确率,
"kappa": kappa系数, "kappa": kappa系数,
"category_F1score": 各类F1分数} "category_F1score": 各类F1分数}
``` ```
当`return_details`为`True`时,返回一个由两个字典构成的二元组,其中第一个元素为上述评价指标,第二个元素为仅包含一个key的字典,其`'confusion_matrix'`键对应值为以Python built-in list存储的混淆矩阵。 当`return_details`为`True`时,返回一个由两个字典构成的二元组,其中第一个元素为上述评价指标,第二个元素为仅包含一个key的字典,其`'confusion_matrix'`键对应值为以Python built-in list存储的混淆矩阵。
### `BaseClassifier.evaluate()` ### `BaseClassifier.evaluate()`
接口形式:
```python ```python
def evaluate(self, eval_dataset, batch_size=1, return_details=False): def evaluate(self, eval_dataset, batch_size=1, return_details=False):
``` ```
@ -260,12 +272,14 @@ def evaluate(self, eval_dataset, batch_size=1, return_details=False):
输出为一个`collections.OrderedDict`对象,包含如下键值对: 输出为一个`collections.OrderedDict`对象,包含如下键值对:
``` ```
{"top1": top1准确率 {"top1": top1准确率,
"top5": `top5准确率} "top5": `top5准确率}
``` ```
### `BaseDetector.evaluate()` ### `BaseDetector.evaluate()`
接口形式:
```python ```python
def evaluate(self, def evaluate(self,
eval_dataset, eval_dataset,
@ -292,13 +306,15 @@ def evaluate(self,
当`return_details`为`True`时,返回一个由两个字典构成的二元组,其中第一个字典为上述评价指标,第二个字典包含如下3个键值对: 当`return_details`为`True`时,返回一个由两个字典构成的二元组,其中第一个字典为上述评价指标,第二个字典包含如下3个键值对:
``` ```
{"gt": 数据集标注信息 {"gt": 数据集标注信息,
"bbox": 预测得到的目标框信息, "bbox": 预测得到的目标框信息,
"mask": 预测得到的掩模图信息} "mask": 预测得到的掩模图信息}
``` ```
### `BaseSegmenter.evaluate()` ### `BaseSegmenter.evaluate()`
接口形式:
```python ```python
def evaluate(self, eval_dataset, batch_size=1, return_details=False): def evaluate(self, eval_dataset, batch_size=1, return_details=False):
``` ```
@ -314,12 +330,12 @@ def evaluate(self, eval_dataset, batch_size=1, return_details=False):
当`return_details`为`False`(默认行为)时,输出为一个`collections.OrderedDict`对象,包含如下键值对: 当`return_details`为`False`(默认行为)时,输出为一个`collections.OrderedDict`对象,包含如下键值对:
``` ```
{"miou": mIoU指标 {"miou": mIoU指标,
"category_iou": 各类的IoU指标, "category_iou": 各类的IoU指标,
"oacc": 总体精度(准确率), "oacc": 总体精度(准确率),
"category_acc": 各类精确率, "category_acc": 各类精确率,
"kappa": kappa系数, "kappa": kappa系数,
"category_F1score": 各类F1分数} "category_F1score": 各类F1分数}
``` ```
当`return_details`为`True`时,返回一个由两个字典构成的二元组,其中第一个元素为上述评价指标,第二个元素为仅包含一个key的字典,其`'confusion_matrix'`键对应值为以Python built-in list存储的混淆矩阵。 当`return_details`为`True`时,返回一个由两个字典构成的二元组,其中第一个元素为上述评价指标,第二个元素为仅包含一个key的字典,其`'confusion_matrix'`键对应值为以Python built-in list存储的混淆矩阵。

@ -1,37 +1,33 @@
# 数据预处理/数据增强 # 数据预处理/数据增强
## 读取各种格式的遥感影像 ## PaddleRS已支持的数据变换算子列表
遥感影像的来源多样,数据格式十分繁杂。PaddleRS为不同类型、不同格式的遥感影像提供了统一的读取接口,只需向`paddlers.transforms.decode_image()`函数传入影像路径,即可将其中的数据信息读取至内存。目前,`paddlers.transforms.decode_image()`支持.png、.jpg、.bmp、.npy等常见文件格式,也支持遥感领域常用的GeoTiff、img等影像格式。 PaddleRS对不同遥感任务需要的数据预处理/数据增强(合称为数据变换)策略进行了有机整合,设计统一的算子。考虑到遥感影像的多波段特性,PaddleRS的大部分数据处理算子均能够处理任意数量波段的输入。PaddleRS目前提供的所有数据变换算子如下表:
## PaddleRS已支持的数据预处理/数据增强算子列表 | 数据变换算子名 | 用途 | 任务 | ... |
PaddleRS对不同遥感任务需要的数据预处理/数据增强策略进行了有机整合,设计统一的算子。考虑到遥感影像的多波段特性,PaddleRS的大部分数据处理算子均能够处理任意数量波段的输入。PaddleRS目前提供的所有数据预处理/数据增强算子如下表:
| 数据预处理/数据增强算子名 | 用途 | 任务 | ... |
| -------------------- | ------------------------------------------------- | -------- | ---- | | -------------------- | ------------------------------------------------- | -------- | ---- |
| Resize | 调整输入影像大小 | 所有任务 | ... | | Resize | 调整输入影像大小。 | 所有任务 | ... |
| RandomResize | 随机调整输入影像大小 | 所有任务 | ... | | RandomResize | 随机调整输入影像大小。 | 所有任务 | ... |
| ResizeByShort | 调整输入影像大小,保持纵横比不变(根据短边计算缩放系数) | 所有任务 | ... | | ResizeByShort | 调整输入影像大小,保持纵横比不变(根据短边计算缩放系数)。 | 所有任务 | ... |
| RandomResizeByShort | 随机调整输入影像大小,保持纵横比不变(根据短边计算缩放系数) | 所有任务 | ... | | RandomResizeByShort | 随机调整输入影像大小,保持纵横比不变(根据短边计算缩放系数) | 所有任务 | ... |
| ResizeByLong | 调整输入影像大小,保持纵横比不变(根据长边计算缩放系数) | 所有任务 | ... | | ResizeByLong | 调整输入影像大小,保持纵横比不变(根据长边计算缩放系数) | 所有任务 | ... |
| RandomHorizontalFlip | 随机水平翻转输入影像 | 所有任务 | ... | | RandomHorizontalFlip | 随机水平翻转输入影像 | 所有任务 | ... |
| RandomVerticalFlip | 随机竖直翻转输入影像 | 所有任务 | ... | | RandomVerticalFlip | 随机竖直翻转输入影像 | 所有任务 | ... |
| Normalize | 对输入影像应用标准化 | 所有任务 | ... | | Normalize | 对输入影像应用标准化 | 所有任务 | ... |
| CenterCrop | 对输入影像进行中心裁剪 | 所有任务 | ... | | CenterCrop | 对输入影像进行中心裁剪 | 所有任务 | ... |
| RandomCrop | 对输入影像进行随机中心裁剪 | 所有任务 | ... | | RandomCrop | 对输入影像进行随机中心裁剪 | 所有任务 | ... |
| RandomScaleAspect | 裁剪输入影像并重新缩放到原始尺寸 | 所有任务 | ... | | RandomScaleAspect | 裁剪输入影像并重新缩放到原始尺寸 | 所有任务 | ... |
| RandomExpand | 根据随机偏移扩展输入影像 | 所有任务 | ... | | RandomExpand | 根据随机偏移扩展输入影像 | 所有任务 | ... |
| Pad | 将输入影像填充到指定的大小 | 所有任务 | ... | | Pad | 将输入影像填充到指定的大小 | 所有任务 | ... |
| MixupImage | 将两幅影像(及对应的目标检测标注)混合在一起作为新的样本 | 目标检测 | ... | | MixupImage | 将两幅影像(及对应的目标检测标注)混合在一起作为新的样本 | 目标检测 | ... |
| RandomDistort | 对输入施加随机色彩变换 | 所有任务 | ... | | RandomDistort | 对输入施加随机色彩变换 | 所有任务 | ... |
| RandomBlur | 对输入施加随机模糊 | 所有任务 | ... | | RandomBlur | 对输入施加随机模糊 | 所有任务 | ... |
| Dehaze | 对输入图像进行去雾 | 所有任务 | ... | | Dehaze | 对输入图像进行去雾 | 所有任务 | ... |
| ReduceDim | 对输入图像进行波段降维 | 所有任务 | ... | | ReduceDim | 对输入图像进行波段降维 | 所有任务 | ... |
| SelectBand | 对输入影像进行波段选择 | 所有任务 | ... | | SelectBand | 对输入影像进行波段选择 | 所有任务 | ... |
| RandomSwap | 随机交换两个时相的输入影像 | 变化检测 | ... | | RandomSwap | 随机交换两个时相的输入影像 | 变化检测 | ... |
| ... | ... | ... | ... | | ... | ... | ... | ... |
## 组合算子 ## 组合算子
在实际的模型训练过程中,常常需要组合多种数据预处理与数据增强策略。PaddleRS提供了`paddlers.transforms.Compose`类以便捷地组合多个数据预处理/数据增强算子,使这些算子能够串行执行。关于`paddlers.transforms.Compose`的具体用法请参见[API说明](https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md)。 在实际的模型训练过程中,常常需要组合多种数据预处理与数据增强策略。PaddleRS提供了`paddlers.transforms.Compose`以便捷地组合多个数据变换算子,使这些算子能够串行执行。关于`paddlers.transforms.Compose`的具体用法请参见[API说明](https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md)。

@ -438,7 +438,9 @@ class BaseClassifier(BaseModel):
Returns: Returns:
If `img_file` is a string or np.array, the result is a dict with key-value If `img_file` is a string or np.array, the result is a dict with key-value
pairs: pairs:
{"label map": `class_ids_map`, "scores_map": `label_names_map`}. {"label map": `class_ids_map`,
"scores_map": `scores_map`,
"label_names_map": `label_names_map`}.
If `img_file` is a list, the result is a list composed of dicts with the If `img_file` is a list, the result is a list composed of dicts with the
corresponding fields: corresponding fields:
class_ids_map (np.ndarray): class_ids class_ids_map (np.ndarray): class_ids

@ -559,9 +559,13 @@ class BaseDetector(BaseModel):
Returns: Returns:
If `img_file` is a string or np.array, the result is a list of dict with If `img_file` is a string or np.array, the result is a list of dict with
key-value pairs: key-value pairs:
{"category_id": `category_id`, "category": `category`, "bbox": `[x, y, w, h]`, "score": `score`}. {"category_id": `category_id`,
If `img_file` is a list, the result is a list composed of dicts with the "category": `category`,
corresponding fields: "bbox": `[x, y, w, h]`,
"score": `score`,
"mask": `mask`}.
If `img_file` is a list, the result is a list composed of list of dicts
with the corresponding fields:
category_id(int): the predicted category ID. 0 represents the first category_id(int): the predicted category ID. 0 represents the first
category in the dataset, and so on. category in the dataset, and so on.
category(str): category name category(str): category name

@ -24,11 +24,13 @@ def decode_image(im_path,
to_rgb=True, to_rgb=True,
to_uint8=True, to_uint8=True,
decode_bgr=True, decode_bgr=True,
decode_sar=True): decode_sar=True,
read_geo_info=False):
""" """
Decode an image. Decode an image.
Args: Args:
im_path (str): Path of the image to decode.
to_rgb (bool, optional): If True, convert input image(s) from BGR format to to_rgb (bool, optional): If True, convert input image(s) from BGR format to
RGB format. Defaults to True. RGB format. Defaults to True.
to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to
@ -38,9 +40,14 @@ def decode_image(im_path,
decode_sar (bool, optional): If True, automatically interpret a two-channel decode_sar (bool, optional): If True, automatically interpret a two-channel
geo image (e.g. geotiff images) as a SAR image, set this argument to geo image (e.g. geotiff images) as a SAR image, set this argument to
True. Defaults to True. True. Defaults to True.
read_geo_info (bool, optional): If True, read geographical information from
the image. Deafults to False.
Returns: Returns:
np.ndarray: Decoded image. np.ndarray|tuple: If `read_geo_info` is False, return the decoded image.
Otherwise, return a tuple that contains the decoded image and a dictionary
of geographical information (e.g. geographical transform and geographical
projection).
""" """
# Do a presence check. osp.exists() assumes `im_path` is a path-like object. # Do a presence check. osp.exists() assumes `im_path` is a path-like object.
@ -50,11 +57,15 @@ def decode_image(im_path,
to_rgb=to_rgb, to_rgb=to_rgb,
to_uint8=to_uint8, to_uint8=to_uint8,
decode_bgr=decode_bgr, decode_bgr=decode_bgr,
decode_sar=decode_sar) decode_sar=decode_sar,
read_geo_info=read_geo_info)
# Deepcopy to avoid inplace modification # Deepcopy to avoid inplace modification
sample = {'image': copy.deepcopy(im_path)} sample = {'image': copy.deepcopy(im_path)}
sample = decoder(sample) sample = decoder(sample)
return sample['image'] if read_geo_info:
return sample['image'], sample['geo_info_dict']
else:
return sample['image']
def build_transforms(transforms_info): def build_transforms(transforms_info):

@ -180,22 +180,28 @@ class DecodeImg(Transform):
decode_sar (bool, optional): If True, automatically interpret a two-channel decode_sar (bool, optional): If True, automatically interpret a two-channel
geo image (e.g. geotiff images) as a SAR image, set this argument to geo image (e.g. geotiff images) as a SAR image, set this argument to
True. Defaults to True. True. Defaults to True.
read_geo_info (bool, optional): If True, read geographical information from
the image. Deafults to False.
""" """
def __init__(self, def __init__(self,
to_rgb=True, to_rgb=True,
to_uint8=True, to_uint8=True,
decode_bgr=True, decode_bgr=True,
decode_sar=True): decode_sar=True,
read_geo_info=False):
super(DecodeImg, self).__init__() super(DecodeImg, self).__init__()
self.to_rgb = to_rgb self.to_rgb = to_rgb
self.to_uint8 = to_uint8 self.to_uint8 = to_uint8
self.decode_bgr = decode_bgr self.decode_bgr = decode_bgr
self.decode_sar = decode_sar self.decode_sar = decode_sar
self.read_geo_info = False
def read_img(self, img_path): def read_img(self, img_path):
img_format = imghdr.what(img_path) img_format = imghdr.what(img_path)
name, ext = os.path.splitext(img_path) name, ext = os.path.splitext(img_path)
geo_trans, geo_proj = None, None
if img_format == 'tiff' or ext == '.img': if img_format == 'tiff' or ext == '.img':
try: try:
import gdal import gdal
@ -209,7 +215,7 @@ class DecodeImg(Transform):
dataset = gdal.Open(img_path) dataset = gdal.Open(img_path)
if dataset == None: if dataset == None:
raise IOError('Can not open', img_path) raise IOError('Cannot open', img_path)
im_data = dataset.ReadAsArray() im_data = dataset.ReadAsArray()
if im_data.ndim == 2 and self.decode_sar: if im_data.ndim == 2 and self.decode_sar:
im_data = to_intensity(im_data) # is read SAR im_data = to_intensity(im_data) # is read SAR
@ -217,26 +223,38 @@ class DecodeImg(Transform):
else: else:
if im_data.ndim == 3: if im_data.ndim == 3:
im_data = im_data.transpose((1, 2, 0)) im_data = im_data.transpose((1, 2, 0))
return im_data if self.read_geo_info:
geo_trans = dataset.GetGeoTransform()
geo_proj = dataset.GetGeoProjection()
elif img_format in ['jpeg', 'bmp', 'png', 'jpg']: elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
if self.decode_bgr: if self.decode_bgr:
return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH | im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR) cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
else: else:
return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH | im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
cv2.IMREAD_ANYCOLOR) cv2.IMREAD_ANYCOLOR)
elif ext == '.npy': elif ext == '.npy':
return np.load(img_path) im_data = np.load(img_path)
else: else:
raise TypeError("Image format {} is not supported!".format(ext)) raise TypeError("Image format {} is not supported!".format(ext))
if self.read_geo_info:
return im_data, geo_trans, geo_proj
else:
return im_data
def apply_im(self, im_path): def apply_im(self, im_path):
if isinstance(im_path, str): if isinstance(im_path, str):
try: try:
image = self.read_img(im_path) data = self.read_img(im_path)
except: except:
raise ValueError("Cannot read the image file {}!".format( raise ValueError("Cannot read the image file {}!".format(
im_path)) im_path))
if self.read_geo_info:
image, geo_trans, geo_proj = data
geo_info_dict = {'geo_trans': geo_trans, 'geo_proj': geo_proj}
else:
image = data
else: else:
image = im_path image = im_path
@ -246,7 +264,10 @@ class DecodeImg(Transform):
if self.to_uint8: if self.to_uint8:
image = to_uint8(image) image = to_uint8(image)
return image if self.read_geo_info:
return image, geo_info_dict
else:
return image
def apply_mask(self, mask): def apply_mask(self, mask):
try: try:
@ -269,15 +290,37 @@ class DecodeImg(Transform):
""" """
if 'image' in sample: if 'image' in sample:
sample['image_ori'] = copy.deepcopy(sample['image']) if self.read_geo_info:
sample['image'] = self.apply_im(sample['image']) image, geo_info_dict = self.apply_im(sample['image'])
sample['image'] = image
sample['geo_info_dict'] = geo_info_dict
else:
sample['image'] = self.apply_im(sample['image'])
if 'image2' in sample: if 'image2' in sample:
sample['image2'] = self.apply_im(sample['image2']) if self.read_geo_info:
image2, geo_info_dict2 = self.apply_im(sample['image2'])
sample['image2'] = image2
sample['geo_info_dict2'] = geo_info_dict2
else:
sample['image2'] = self.apply_im(sample['image2'])
if 'image_t1' in sample and not 'image' in sample: if 'image_t1' in sample and not 'image' in sample:
if not ('image_t2' in sample and 'image2' not in sample): if not ('image_t2' in sample and 'image2' not in sample):
raise ValueError raise ValueError
sample['image'] = self.apply_im(sample['image_t1']) if self.read_geo_info:
sample['image2'] = self.apply_im(sample['image_t2']) image, geo_info_dict = self.apply_im(sample['image_t1'])
sample['image'] = image
sample['geo_info_dict'] = geo_info_dict
else:
sample['image'] = self.apply_im(sample['image_t1'])
if self.read_geo_info:
image2, geo_info_dict2 = self.apply_im(sample['image_t2'])
sample['image2'] = image2
sample['geo_info_dict2'] = geo_info_dict2
else:
sample['image2'] = self.apply_im(sample['image_t2'])
if 'mask' in sample: if 'mask' in sample:
sample['mask_ori'] = copy.deepcopy(sample['mask']) sample['mask_ori'] = copy.deepcopy(sample['mask'])
sample['mask'] = self.apply_mask(sample['mask']) sample['mask'] = self.apply_mask(sample['mask'])
@ -286,6 +329,7 @@ class DecodeImg(Transform):
if im_height != se_height or im_width != se_width: if im_height != se_height or im_width != se_width:
raise ValueError( raise ValueError(
"The height or width of the image is not same as the mask.") "The height or width of the image is not same as the mask.")
if 'aux_masks' in sample: if 'aux_masks' in sample:
sample['aux_masks_ori'] = copy.deepcopy(sample['aux_masks']) sample['aux_masks_ori'] = copy.deepcopy(sample['aux_masks'])
sample['aux_masks'] = list( sample['aux_masks'] = list(
@ -295,6 +339,7 @@ class DecodeImg(Transform):
sample['im_shape'] = np.array( sample['im_shape'] = np.array(
sample['image'].shape[:2], dtype=np.float32) sample['image'].shape[:2], dtype=np.float32)
sample['scale_factor'] = np.array([1., 1.], dtype=np.float32) sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
return sample return sample

Loading…
Cancel
Save