[Feature] Add deployment code and docs (#43)

Lin Manhui 3 years ago committed by GitHub
parent 1d92bb5f10
commit 7ab3e65a12
18 changed files (lines changed):

  1. deploy/README.md (33)
  2. deploy/export/README.md (62)
  3. deploy/export/export_model.py (59)
  4. paddlers/__init__.py (2)
  5. paddlers/custom_models/cd/bit.py (22)
  6. paddlers/custom_models/cd/changestar.py (8)
  7. paddlers/custom_models/cd/dsamnet.py (4)
  8. paddlers/custom_models/cd/dsifn.py (4)
  9. paddlers/custom_models/cd/layers/attention.py (6)
  10. paddlers/custom_models/cd/layers/blocks.py (16)
  11. paddlers/custom_models/cd/snunet.py (6)
  12. paddlers/custom_models/cd/stanet.py (14)
  13. paddlers/deploy/__init__.py (1)
  14. paddlers/deploy/predictor.py (283)
  15. paddlers/tasks/base.py (11)
  16. paddlers/tasks/changedetector.py (28)
  17. paddlers/tasks/load_model.py (6)
  18. paddlers/tasks/utils/infer_nets.py (13)

@ -0,0 +1,33 @@
# Python Deployment

PaddleRS provides an integrated high-performance Python prediction (inference) API. After installing PaddleRS, you can run predictions following the code examples below.

## Exporting a Deployment Model

When deploying a model on the server side, the model saved during training must first be exported to the deployment format. For detailed steps, see [Deployment Model Export](/deploy/export/README.md).

## Calling the Prediction API

* **Basic usage**

The following example calls the PaddleRS Python prediction API: first construct a `Predictor` object, then call the `Predictor`'s `predict()` method to run prediction.

```python
import paddlers as pdrs
# Pass the directory containing the exported model to the Predictor constructor
predictor = pdrs.deploy.Predictor('./inference_model')
# The img_file argument specifies the path of the input image
result = predictor.predict(img_file='test.jpg')
```

* **Measuring prediction speed during inference**

After the model is loaded, predictions on the first few images are relatively slow, because the program has to initialize memory, GPU memory, and so on at startup. Prediction speed typically stabilizes after some 20 to 30 images have been processed. Based on this observation, **to benchmark the model's prediction speed, warm the model up first by setting the number of warm-up iterations `warmup_iters`**. In addition, **for a more accurate speed estimate, set `repeats` to repeat the prediction several times and average the elapsed time**.
```python
import paddlers as pdrs
predictor = pdrs.deploy.Predictor('./inference_model')
result = predictor.predict(img_file='test.jpg',
warmup_iters=100,
repeats=100)
```
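
* **Predicting with a change detection model**

For change detection models, `predict()` expects the two temporal phases as a two-element tuple (see the `Predictor.predict()` docstring added in this commit). A minimal sketch, where `t1.jpg` and `t2.jpg` are hypothetical paths of the pre- and post-temporal images:
```python
import paddlers as pdrs

predictor = pdrs.deploy.Predictor('./inference_model')
# For change detection, the two temporal images are passed as a tuple.
result = predictor.predict(img_file=('t1.jpg', 't2.jpg'))
```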

@ -0,0 +1,62 @@
# Deployment Model Export

## Contents

* [Model Format Description](#1)
  * [Training Model Format](#11)
  * [Deployment Model Format](#12)
* [Deployment Model Export](#2)

## <h2 id="1">Model Format Description</h2>

### <h3 id="11">Training Model Format</h3>

The output directory of a model trained with PaddleRS mainly contains four files:

- `model.pdopt`, the state of the optimizer used during training;
- `model.pdparams`, the weight parameters of the model;
- `model.yml`, the model configuration file (including preprocessing parameters, model specification, and so on);
- `eval_details.json`, the metrics achieved by the model during validation.

Note that the dynamic-graph version of the model is used during training, so directly deploying the weights and configuration files in the above format is usually inefficient. We recommend exporting the model to a dedicated deployment format and using the static-graph version of the model at deployment time for higher inference efficiency.

### <h3 id="12">Deployment Model Format</h3>

When deploying a model on the server side, the model saved during training must be exported to a dedicated format. Specifically, at deployment time a trained model is described by the following five files:

- `model.pdmodel`, the network structure of the model;
- `model.pdiparams`, the weight parameters of the model;
- `model.pdiparams.info`, the names of the model weights;
- `model.yml`, the model configuration file (including preprocessing parameters, model specification, and so on);
- `pipeline.yml`, the pipeline configuration file.

## <h2 id="2">Deployment Model Export</h2>

Export the model in deployment format with the following command:

```commandline
python deploy/export/export_model.py --model_dir=./output/deeplabv3p/best_model/ --save_dir=./inference_model/
```

The `--model_dir` and `--save_dir` options specify the directories of the training-format model and the deployment-format model, respectively. In the example above, five files will be generated under `./inference_model/`: `model.pdmodel`, `model.pdiparams`, `model.pdiparams.info`, `model.yml`, and `pipeline.yml`.

The `deploy/export/export_model.py` script accepts three command-line options:

| Option | Description |
| ---- | ---- |
| --model_dir | Directory of the training-format model to be exported, e.g. `./output/deeplabv3p/best_model/`. |
| --save_dir | Directory in which the exported deployment-format model is stored, e.g. `./inference_model/`. |
| --fixed_input_shape | Fixed input tensor shape of the exported model. Defaults to None, which means the task's default input tensor shape is used. |

When running inference with TensorRT, the input tensor shape of the model must be fixed. In that case, specify it via the `--fixed_input_shape` option, in one of two forms: `[w,h]` or `[n,c,w,h]`. For example, setting `--fixed_input_shape` to `[224,224]` makes the effective input tensor shape `[-1,3,224,224]` (-1 stands for any positive integer; the number of channels defaults to 3), as illustrated by the sketch after the notes below. To additionally fix the batch dimension to 1 and the number of channels to 4, set the option to `[1,4,224,224]`.

A complete command example:

```commandline
python deploy/export/export_model.py --model_dir=./output/deeplabv3p/best_model/ --save_dir=./inference_model/ --fixed_input_shape=[224,224]
```

Regarding the `--fixed_input_shape` option, **please note**:

- If the input shape of a classification model is fixed at inference time, keep it identical to the input shape used during training.
- For YOLO/PPYOLO-family detection models, `w` and `h` of the input image must be equal and both must be multiples of 32; when `--fixed_input_shape` is specified, `w` and `h` of R-CNN models must also be multiples of 32.
- When specifying `[w,h]`, separate `w` and `h` with a half-width comma (`,`), with no spaces or other characters in between.
- The larger `w` and `h` are, the more time and memory/GPU memory the model consumes during inference. However, overly small `w` and `h` may noticeably hurt model accuracy.
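
The following minimal sketch (illustrative only; the function name is hypothetical and not part of the export script) shows how a `[w,h]` or `[n,c,w,h]` specification maps to the effective input tensor shape described above:

```python
from ast import literal_eval

def effective_input_shape(fixed_input_shape: str):
    """Map a '[w,h]' or '[n,c,w,h]' option string to the effective NCHW shape."""
    shape = literal_eval(fixed_input_shape)
    if len(shape) == 2:
        w, h = shape
        # The batch size is left free (-1) and the channel count defaults to 3.
        return [-1, 3, w, h]
    return list(shape)

print(effective_input_shape('[224,224]'))      # [-1, 3, 224, 224]
print(effective_input_shape('[1,4,224,224]'))  # [1, 4, 224, 224]
```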

@ -0,0 +1,59 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
from ast import literal_eval

from paddlers.tasks import load_model


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', '-m', type=str, default=None, help='model directory path')
    parser.add_argument('--save_dir', '-s', type=str, default=None, help='path to save inference model')
    parser.add_argument('--fixed_input_shape', '-fs', type=str, default=None,
                        help="export inference model with fixed input shape: [w,h] or [n,c,w,h]")
    return parser


if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()

    # Get input shape
    fixed_input_shape = None
    if args.fixed_input_shape is not None:
        # Try to interpret the string as a Python list literal.
        fixed_input_shape = literal_eval(args.fixed_input_shape)
        # Check validity
        if not isinstance(fixed_input_shape, list):
            raise ValueError("fixed_input_shape should be of None or list type.")
        if len(fixed_input_shape) not in (2, 4):
            raise ValueError("fixed_input_shape contains an incorrect number of elements.")
        if fixed_input_shape[-1] <= 0 or fixed_input_shape[-2] <= 0:
            raise ValueError("The input width and height must be positive integers.")
        if len(fixed_input_shape) == 4 and fixed_input_shape[1] <= 0:
            raise ValueError("The number of input channels must be a positive integer.")

    # Set environment variables
    os.environ['PADDLEX_EXPORT_STAGE'] = 'True'
    os.environ['PADDLESEG_EXPORT_STAGE'] = 'True'

    # Load the model from the directory
    model = load_model(args.model_dir)
    # Convert the dynamic graph to a static graph
    # XXX: This invokes a protected (single-underscore) method outside of subclasses.
    model._export_inference_model(args.save_dir, fixed_input_shape)

@ -21,4 +21,4 @@ env_info = get_environ_info()
log_level = 2
from . import tasks, datasets, transforms, utils, tools, models
from . import tasks, datasets, transforms, utils, tools, models, deploy

@ -71,7 +71,7 @@ class BIT(nn.Layer):
dec_depth=8,
dec_head_dim=8,
**backbone_kwargs):
super().__init__()
super(BIT, self).__init__()
# TODO: reduce hard-coded parameters
DIM = 32
@ -197,7 +197,7 @@ class BIT(nn.Layer):
class Residual(nn.Layer):
def __init__(self, fn):
super().__init__()
super(Residual, self).__init__()
self.fn = fn
def forward(self, x, **kwargs):
@ -206,7 +206,7 @@ class Residual(nn.Layer):
class Residual2(nn.Layer):
def __init__(self, fn):
super().__init__()
super(Residual2, self).__init__()
self.fn = fn
def forward(self, x1, x2, **kwargs):
@ -215,7 +215,7 @@ class Residual2(nn.Layer):
class PreNorm(nn.Layer):
def __init__(self, dim, fn):
super().__init__()
super(PreNorm, self).__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
@ -225,7 +225,7 @@ class PreNorm(nn.Layer):
class PreNorm2(nn.Layer):
def __init__(self, dim, fn):
super().__init__()
super(PreNorm2, self).__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
@ -235,7 +235,7 @@ class PreNorm2(nn.Layer):
class FeedForward(nn.Sequential):
def __init__(self, dim, hidden_dim, dropout_rate=0.):
super().__init__(
super(FeedForward, self).__init__(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout_rate),
@ -249,7 +249,7 @@ class CrossAttention(nn.Layer):
head_dim=64,
dropout_rate=0.,
apply_softmax=True):
super().__init__()
super(CrossAttention, self).__init__()
inner_dim = head_dim * n_heads
self.n_heads = n_heads
@ -288,12 +288,12 @@ class CrossAttention(nn.Layer):
class SelfAttention(CrossAttention):
def forward(self, x):
return super().forward(x, x)
return super(SelfAttention, self).forward(x, x)
class TransformerEncoder(nn.Layer):
def __init__(self, dim, depth, n_heads, head_dim, mlp_dim, dropout_rate):
super().__init__()
super(TransformerEncoder, self).__init__()
self.layers = nn.LayerList([])
for _ in range(depth):
self.layers.append(
@ -322,7 +322,7 @@ class TransformerDecoder(nn.Layer):
mlp_dim,
dropout_rate,
apply_softmax=True):
super().__init__()
super(TransformerDecoder, self).__init__()
self.layers = nn.LayerList([])
for _ in range(depth):
self.layers.append(
@ -349,7 +349,7 @@ class Backbone(nn.Layer, KaimingInitMixin):
arch='resnet18',
pretrained=True,
n_stages=5):
super().__init__()
super(Backbone, self).__init__()
expand = 1
strides = (2, 1, 2, 1, 1)

@ -28,7 +28,7 @@ class _ChangeStarBase(nn.Layer):
def __init__(self, seg_model, num_classes, mid_channels, inner_channels,
num_convs, scale_factor):
super().__init__()
super(_ChangeStarBase, self).__init__()
self.extract = seg_model
self.detect = ChangeMixin(
@ -63,7 +63,7 @@ class _ChangeStarBase(nn.Layer):
class ChangeMixin(nn.Layer):
def __init__(self, in_ch, out_ch, mid_ch, num_convs, scale_factor):
super().__init__()
super(ChangeMixin, self).__init__()
convs = [Conv3x3(in_ch, mid_ch, norm=True, act=True)]
convs += [
Conv3x3(
@ -112,7 +112,7 @@ class ChangeStar_FarSeg(_ChangeStarBase):
# TODO: Configurable FarSeg model
class _FarSegWrapper(nn.Layer):
def __init__(self, seg_model):
super().__init__()
super(_FarSegWrapper, self).__init__()
self._seg_model = seg_model
self._seg_model.cls_pred_conv = Identity()
@ -131,7 +131,7 @@ class ChangeStar_FarSeg(_ChangeStarBase):
seg_model = FarSeg(out_ch=mid_channels)
super().__init__(
super(ChangeStar_FarSeg, self).__init__(
seg_model=_FarSegWrapper(seg_model),
num_classes=num_classes,
mid_channels=mid_channels,

@ -41,7 +41,7 @@ class DSAMNet(nn.Layer):
"""
def __init__(self, in_channels, num_classes, ca_ratio=8, sa_kernel=7):
super().__init__()
super(DSAMNet, self).__init__()
WIDTH = 64
@ -90,7 +90,7 @@ class DSAMNet(nn.Layer):
class DSLayer(nn.Sequential):
def __init__(self, in_ch, out_ch, itm_ch, **convd_kwargs):
super().__init__(
super(DSLayer, self).__init__(
nn.Conv2DTranspose(
in_ch, itm_ch, kernel_size=3, padding=1, **convd_kwargs),
make_norm(itm_ch),

@ -41,7 +41,7 @@ class DSIFN(nn.Layer):
"""
def __init__(self, num_classes, use_dropout=False):
super().__init__()
super(DSIFN, self).__init__()
self.encoder1 = self.encoder2 = VGG16FeaturePicker()
@ -191,7 +191,7 @@ class DSIFN(nn.Layer):
class VGG16FeaturePicker(nn.Layer):
def __init__(self, indices=(3, 8, 15, 22, 29)):
super().__init__()
super(VGG16FeaturePicker, self).__init__()
features = list(vgg16(pretrained=True).features)[:30]
self.features = nn.LayerList(features)
self.features.eval()

@ -33,7 +33,7 @@ class ChannelAttention(nn.Layer):
"""
def __init__(self, in_ch, ratio=8):
super().__init__()
super(ChannelAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.max_pool = nn.AdaptiveMaxPool2D(1)
self.fc1 = Conv1x1(in_ch, in_ch // ratio, bias=False, act=True)
@ -59,7 +59,7 @@ class SpatialAttention(nn.Layer):
"""
def __init__(self, kernel_size=7):
super().__init__()
super(SpatialAttention, self).__init__()
self.conv = BasicConv(2, 1, kernel_size, bias=False)
def forward(self, x):
@ -85,7 +85,7 @@ class CBAM(nn.Layer):
"""
def __init__(self, in_ch, ratio=8, kernel_size=7):
super().__init__()
super(CBAM, self).__init__()
self.ca = ChannelAttention(in_ch, ratio=ratio)
self.sa = SpatialAttention(kernel_size=kernel_size)

@ -51,7 +51,7 @@ class BasicConv(nn.Layer):
norm=False,
act=False,
**kwargs):
super().__init__()
super(BasicConv, self).__init__()
seq = []
if kernel_size >= 2:
seq.append(nn.Pad2D(kernel_size // 2, mode=pad_mode))
@ -87,7 +87,7 @@ class Conv1x1(BasicConv):
norm=False,
act=False,
**kwargs):
super().__init__(
super(Conv1x1, self).__init__(
in_ch,
out_ch,
1,
@ -107,7 +107,7 @@ class Conv3x3(BasicConv):
norm=False,
act=False,
**kwargs):
super().__init__(
super(Conv3x3, self).__init__(
in_ch,
out_ch,
3,
@ -127,7 +127,7 @@ class Conv7x7(BasicConv):
norm=False,
act=False,
**kwargs):
super().__init__(
super(Conv7x7, self).__init__(
in_ch,
out_ch,
7,
@ -140,12 +140,12 @@ class Conv7x7(BasicConv):
class MaxPool2x2(nn.MaxPool2D):
def __init__(self, **kwargs):
super().__init__(kernel_size=2, stride=(2, 2), padding=(0, 0), **kwargs)
super(MaxPool2x2, self).__init__(kernel_size=2, stride=(2, 2), padding=(0, 0), **kwargs)
class MaxUnPool2x2(nn.MaxUnPool2D):
def __init__(self, **kwargs):
super().__init__(kernel_size=2, stride=(2, 2), padding=(0, 0), **kwargs)
super(MaxUnPool2x2, self).__init__(kernel_size=2, stride=(2, 2), padding=(0, 0), **kwargs)
class ConvTransposed3x3(nn.Layer):
@ -156,7 +156,7 @@ class ConvTransposed3x3(nn.Layer):
norm=False,
act=False,
**kwargs):
super().__init__()
super(ConvTransposed3x3, self).__init__()
seq = []
seq.append(
nn.Conv2DTranspose(
@ -185,7 +185,7 @@ class Identity(nn.Layer):
"""A placeholder identity operator that accepts exactly one argument."""
def __init__(self, *args, **kwargs):
super().__init__()
super(Identity, self).__init__()
def forward(self, x):
return x

@ -39,7 +39,7 @@ class SNUNet(nn.Layer, KaimingInitMixin):
"""
def __init__(self, in_channels, num_classes, width=32):
super().__init__()
super(SNUNet, self).__init__()
filters = (width, width * 2, width * 4, width * 8, width * 16)
@ -142,7 +142,7 @@ class SNUNet(nn.Layer, KaimingInitMixin):
class ConvBlockNested(nn.Layer):
def __init__(self, in_ch, out_ch, mid_ch):
super().__init__()
super(ConvBlockNested, self).__init__()
self.act = nn.ReLU()
self.conv1 = nn.Conv2D(in_ch, mid_ch, kernel_size=3, padding=1)
self.bn1 = make_norm(mid_ch)
@ -163,7 +163,7 @@ class ConvBlockNested(nn.Layer):
class Up(nn.Layer):
def __init__(self, in_ch, use_conv=False):
super().__init__()
super(Up, self).__init__()
if use_conv:
self.up = nn.Conv2DTranspose(in_ch, in_ch, 2, stride=2)
else:

@ -46,7 +46,7 @@ class STANet(nn.Layer):
"""
def __init__(self, in_channels, num_classes, att_type='BAM', ds_factor=1):
super().__init__()
super(STANet, self).__init__()
WIDTH = 64
@ -94,7 +94,7 @@ def build_sta_module(in_ch, att_type, ds):
class Backbone(nn.Layer, KaimingInitMixin):
def __init__(self, in_ch, arch, pretrained=True, strides=(2, 1, 2, 2, 2)):
super().__init__()
super(Backbone, self).__init__()
if arch == 'resnet18':
self.resnet = resnet.resnet18(
@ -148,7 +148,7 @@ class Backbone(nn.Layer, KaimingInitMixin):
class Decoder(nn.Layer, KaimingInitMixin):
def __init__(self, f_ch):
super().__init__()
super(Decoder, self).__init__()
self.dr1 = Conv1x1(64, 96, norm=True, act=True)
self.dr2 = Conv1x1(128, 96, norm=True, act=True)
self.dr3 = Conv1x1(256, 96, norm=True, act=True)
@ -183,7 +183,7 @@ class Decoder(nn.Layer, KaimingInitMixin):
class BAM(nn.Layer):
def __init__(self, in_ch, ds):
super().__init__()
super(BAM, self).__init__()
self.ds = ds
self.pool = nn.AvgPool2D(self.ds)
@ -220,7 +220,7 @@ class BAM(nn.Layer):
class PAMBlock(nn.Layer):
def __init__(self, in_ch, scale=1, ds=1):
super().__init__()
super(PAMBlock, self).__init__()
self.scale = scale
self.ds = ds
@ -280,7 +280,7 @@ class PAMBlock(nn.Layer):
class PAM(nn.Layer):
def __init__(self, in_ch, ds, scales=(1, 2, 4, 8)):
super().__init__()
super(PAM, self).__init__()
self.stages = nn.LayerList(
[PAMBlock(
@ -296,7 +296,7 @@ class PAM(nn.Layer):
class Attention(nn.Layer):
def __init__(self, att):
super().__init__()
super(Attention, self).__init__()
self.att = att
def forward(self, x1, x2):

@ -0,0 +1 @@
from .predictor import Predictor

@ -0,0 +1,283 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import numpy as np
from paddle.inference import Config
from paddle.inference import create_predictor
from paddle.inference import PrecisionType
from paddlers.tasks import load_model
from paddlers.utils import logging, Timer
class Predictor(object):
    def __init__(self,
                 model_dir,
                 use_gpu=False,
                 gpu_id=0,
                 cpu_thread_num=1,
                 use_mkl=True,
                 mkl_thread_num=4,
                 use_trt=False,
                 use_glog=False,
                 memory_optimize=True,
                 max_trt_batch_size=1,
                 trt_precision_mode='float32'):
        """
        Create a Paddle Predictor.

        Args:
            model_dir: Model directory. The model must be an exported deployment
                or quantized model.
            use_gpu: Whether to use the GPU. Defaults to False.
            gpu_id: ID of the GPU to use. Defaults to 0.
            cpu_thread_num: Number of threads to use when predicting on the CPU.
                Defaults to 1.
            use_mkl: Whether to use the MKL-DNN library (CPU only). Defaults to True.
            mkl_thread_num: Number of MKL-DNN compute threads. Defaults to 4.
            use_trt: Whether to use TensorRT. Defaults to False.
            use_glog: Whether to enable glog logging. Defaults to False.
            memory_optimize: Whether to enable memory optimization. Defaults to True.
            max_trt_batch_size: Maximum batch size configured when TensorRT is used.
                Defaults to 1.
            trt_precision_mode: Precision used when TensorRT is enabled. One of
                ['float32', 'float16']. Defaults to 'float32'.
        """
        self.model_dir = model_dir
        self._model = load_model(model_dir, with_net=False)

        if trt_precision_mode.lower() == 'float32':
            trt_precision_mode = PrecisionType.Float32
        elif trt_precision_mode.lower() == 'float16':
            trt_precision_mode = PrecisionType.Float16
        else:
            logging.error(
                "TensorRT precision mode {} is invalid. Supported modes are float32 and float16."
                .format(trt_precision_mode),
                exit=True)

        self.predictor = self.create_predictor(
            use_gpu=use_gpu,
            gpu_id=gpu_id,
            cpu_thread_num=cpu_thread_num,
            use_mkl=use_mkl,
            mkl_thread_num=mkl_thread_num,
            use_trt=use_trt,
            use_glog=use_glog,
            memory_optimize=memory_optimize,
            max_trt_batch_size=max_trt_batch_size,
            trt_precision_mode=trt_precision_mode)
        self.timer = Timer()

    def create_predictor(self,
                         use_gpu=True,
                         gpu_id=0,
                         cpu_thread_num=1,
                         use_mkl=True,
                         mkl_thread_num=4,
                         use_trt=False,
                         use_glog=False,
                         memory_optimize=True,
                         max_trt_batch_size=1,
                         trt_precision_mode=PrecisionType.Float32):
        config = Config(
            osp.join(self.model_dir, 'model.pdmodel'),
            osp.join(self.model_dir, 'model.pdiparams'))

        if use_gpu:
            # Set the initial GPU memory footprint (in MB) and the device ID
            config.enable_use_gpu(200, gpu_id)
            config.switch_ir_optim(True)
            if use_trt:
                if self._model.model_type == 'segmenter':
                    logging.warning(
                        "Semantic segmentation models do not support TensorRT acceleration, "
                        "TensorRT is forcibly disabled.")
                elif 'RCNN' in self._model.__class__.__name__:
                    logging.warning(
                        "RCNN models do not support TensorRT acceleration, "
                        "TensorRT is forcibly disabled.")
                else:
                    config.enable_tensorrt_engine(
                        workspace_size=1 << 10,
                        max_batch_size=max_trt_batch_size,
                        min_subgraph_size=3,
                        precision_mode=trt_precision_mode,
                        use_static=False,
                        use_calib_mode=False)
        else:
            config.disable_gpu()
            config.set_cpu_math_library_num_threads(cpu_thread_num)
            if use_mkl:
                if self._model.__class__.__name__ == 'MaskRCNN':
                    logging.warning(
                        "MaskRCNN does not support MKL-DNN, MKL-DNN is forcibly disabled."
                    )
                else:
                    try:
                        # Cache 10 different input shapes for MKL-DNN to avoid memory leaks
                        config.set_mkldnn_cache_capacity(10)
                        config.enable_mkldnn()
                        config.set_cpu_math_library_num_threads(mkl_thread_num)
                    except Exception:
                        logging.warning(
                            "The current environment does not support MKL-DNN, MKL-DNN is disabled."
                        )

        if not use_glog:
            config.disable_glog_info()
        if memory_optimize:
            config.enable_memory_optim()
        config.switch_use_feed_fetch_ops(False)
        predictor = create_predictor(config)
        return predictor

    def preprocess(self, images, transforms):
        preprocessed_samples = self._model._preprocess(
            images, transforms, to_tensor=False)
        if self._model.model_type == 'classifier':
            preprocessed_samples = {'image': preprocessed_samples[0]}
        elif self._model.model_type == 'segmenter':
            preprocessed_samples = {
                'image': preprocessed_samples[0],
                'ori_shape': preprocessed_samples[1]
            }
        elif self._model.model_type == 'detector':
            pass
        elif self._model.model_type == 'changedetector':
            preprocessed_samples = {
                'image': preprocessed_samples[0],
                'image2': preprocessed_samples[1],
                'ori_shape': preprocessed_samples[2]
            }
        else:
            logging.error(
                "Invalid model type {}.".format(self._model.model_type),
                exit=True)
        return preprocessed_samples

    def postprocess(self, net_outputs, topk=1, ori_shape=None, transforms=None):
        if self._model.model_type == 'classifier':
            true_topk = min(self._model.num_classes, topk)
            preds = self._model._postprocess(net_outputs[0], true_topk)
        elif self._model.model_type in ('segmenter', 'changedetector'):
            label_map, score_map = self._model._postprocess(
                net_outputs,
                batch_origin_shape=ori_shape,
                transforms=transforms.transforms)
            preds = [{
                'label_map': l,
                'score_map': s
            } for l, s in zip(label_map, score_map)]
        elif self._model.model_type == 'detector':
            net_outputs = {
                k: v
                for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs)
            }
            preds = self._model._postprocess(net_outputs)
        else:
            logging.error(
                "Invalid model type {}.".format(self._model.model_type),
                exit=True)
        return preds

    def raw_predict(self, inputs):
        """
        Run inference on preprocessed data.

        Args:
            inputs(dict): Preprocessed data.
        """
        input_names = self.predictor.get_input_names()
        for name in input_names:
            input_tensor = self.predictor.get_input_handle(name)
            input_tensor.copy_from_cpu(inputs[name])
        self.predictor.run()
        output_names = self.predictor.get_output_names()
        net_outputs = list()
        for name in output_names:
            output_tensor = self.predictor.get_output_handle(name)
            net_outputs.append(output_tensor.copy_to_cpu())
        return net_outputs

    def _run(self, images, topk=1, transforms=None):
        self.timer.preprocess_time_s.start()
        preprocessed_input = self.preprocess(images, transforms)
        self.timer.preprocess_time_s.end(iter_num=len(images))

        self.timer.inference_time_s.start()
        net_outputs = self.raw_predict(preprocessed_input)
        self.timer.inference_time_s.end(iter_num=1)

        self.timer.postprocess_time_s.start()
        results = self.postprocess(
            net_outputs,
            topk,
            ori_shape=preprocessed_input.get('ori_shape', None),
            transforms=transforms)
        self.timer.postprocess_time_s.end(iter_num=len(images))

        return results

    def predict(self,
                img_file,
                topk=1,
                transforms=None,
                warmup_iters=0,
                repeats=1):
        """
        Predict on images.

        Args:
            img_file(List[np.ndarray or str], str or np.ndarray):
                For scene classification, image restoration, object detection, and
                semantic segmentation tasks, this can be a single image path, a decoded
                BGR image (an np.ndarray with layout [H, W, C] and dtype float32), or a
                list of image paths or np.ndarray objects. For change detection tasks,
                this can be a two-element tuple of image paths (the pre-temporal and
                post-temporal images), a two-element tuple of decoded images, or a list
                of either kind of tuple.
            topk(int): Used when predicting with scene classification models; the top-k
                results are returned. Defaults to 1.
            transforms (paddlers.transforms): Preprocessing pipeline. Defaults to None,
                which means the pipeline saved in `model.yml` is used.
            warmup_iters (int): Number of warm-up iterations, used to benchmark inference
                and pre/postprocessing speed. If greater than 0, prediction is repeated
                warmup_iters times before the timed runs start. Defaults to 0.
            repeats (int): Number of repetitions, used to benchmark inference and
                pre/postprocessing speed. If greater than 1, prediction is repeated
                repeats times and the average time is reported. Defaults to 1.
        """
        if repeats < 1:
            logging.error("`repeats` must be a positive integer.", exit=True)
        if transforms is None and not hasattr(self._model, 'test_transforms'):
            raise Exception(
                "Transforms are not specified and the model has no default test_transforms."
            )
        if transforms is None:
            transforms = self._model.test_transforms
        if isinstance(img_file, tuple) and len(img_file) != 2:
            raise ValueError(
                f"A change detection model accepts exactly two input images, but there are {len(img_file)}."
            )
        if isinstance(img_file, (str, np.ndarray, tuple)):
            images = [img_file]
        else:
            images = img_file

        for _ in range(warmup_iters):
            self._run(images=images, topk=topk, transforms=transforms)
        self.timer.reset()

        for _ in range(repeats):
            results = self._run(images=images, topk=topk, transforms=transforms)

        self.timer.repeats = repeats
        self.timer.img_num = len(images)
        self.timer.info(average=True)

        if isinstance(img_file, (str, np.ndarray)):
            results = results[0]

        return results

    def batch_predict(self, image_list, **params):
        return self.predict(img_file=image_list, **params)
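
For reference, a brief usage sketch of the `Predictor` API defined above (paths and option values are illustrative, not part of this commit):

```python
import paddlers as pdrs

# CPU inference; MKL-DNN is enabled by default per the constructor signature above.
predictor = pdrs.deploy.Predictor('./inference_model', cpu_thread_num=4)

# GPU inference with TensorRT at float16 precision.
predictor_trt = pdrs.deploy.Predictor(
    './inference_model',
    use_gpu=True,
    gpu_id=0,
    use_trt=True,
    trt_precision_mode='float16')

result = predictor.predict(img_file='test.jpg')
```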

@ -33,7 +33,7 @@ from paddlers.utils import (seconds_to_hms, get_single_card_bs, dict2str,
_get_shared_memory_size_in_M, EarlyStop)
import paddlers.utils.logging as logging
from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune
from .utils.infer_nets import InferNet
from .utils.infer_nets import InferNet, InferCDNet
class BaseModel:
@ -580,13 +580,16 @@ class BaseModel:
return pipeline_info
def _build_inference_net(self):
infer_net = self.net if self.model_type == 'detector' else InferNet(
self.net, self.model_type)
if self.model_type == 'detector':
infer_net = self.net
elif self.model_type == 'changedetector':
infer_net = InferCDNet(self.net)
else:
infer_net = InferNet(self.net, self.model_type)
infer_net.eval()
return infer_net
def _export_inference_model(self, save_dir, image_shape=None):
save_dir = osp.join(save_dir, 'inference_model')
self.test_inputs = self._get_test_inputs(image_shape)
infer_net = self._build_inference_net()

@ -98,11 +98,11 @@ class BaseChangeDetector(BaseModel):
else:
image_shape = [None, 3, -1, -1]
self.fixed_input_shape = image_shape
input_spec = [
return [
InputSpec(
shape=image_shape, name='image', dtype='float32')
shape=image_shape, name='image', dtype='float32'), InputSpec(
shape=image_shape, name='image2', dtype='float32')
]
return input_spec
def run(self, net, inputs, mode):
net_out = net(inputs[0], inputs[1])
@ -532,22 +532,26 @@ class BaseChangeDetector(BaseModel):
def _preprocess(self, images, transforms, to_tensor=True):
arrange_transforms(
model_type=self.model_type, transforms=transforms, mode='test')
batch_im = list()
batch_im1, batch_im2 = list(), list()
batch_ori_shape = list()
for im in images:
sample = {'image': im}
if isinstance(sample['image'], str):
for im1, im2 in images:
sample = {'image_t1': im1, 'image_t2': im2}
if isinstance(sample['image_t1'], str) or \
isinstance(sample['image_t2'], str):
sample = ImgDecoder(to_rgb=False)(sample)
ori_shape = sample['image'].shape[:2]
im = transforms(sample)[0]
batch_im.append(im)
im1, im2 = transforms(sample)[:2]
batch_im1.append(im1)
batch_im2.append(im2)
batch_ori_shape.append(ori_shape)
if to_tensor:
batch_im = paddle.to_tensor(batch_im)
batch_im1 = paddle.to_tensor(batch_im1)
batch_im2 = paddle.to_tensor(batch_im2)
else:
batch_im = np.asarray(batch_im)
batch_im1 = np.asarray(batch_im1)
batch_im2 = np.asarray(batch_im2)
return batch_im, batch_ori_shape
return batch_im1, batch_im2, batch_ori_shape
@staticmethod
def get_transforms_shape_info(batch_ori_shape, transforms):

@ -61,12 +61,6 @@ def load_model(model_dir, **params):
model_info = yaml.load(f.read(), Loader=yaml.Loader)
f.close()
version = model_info['version']
if int(version.split('.')[0]) < 2:
raise Exception(
'Current version is {}, a model trained by PaddleRS={} cannot be load.'.
format(paddlers.__version__, version))
status = model_info['status']
with_net = params.get('with_net', True)
if not with_net:

@ -43,3 +43,16 @@ class InferNet(paddle.nn.Layer):
outputs = self.postprocessor(net_outputs)
return outputs
class InferCDNet(paddle.nn.Layer):
def __init__(self, net):
super(InferCDNet, self).__init__()
self.net = net
self.postprocessor = PostProcessor('changedetector')
def forward(self, x1, x2):
net_outputs = self.net(x1, x2)
outputs = self.postprocessor(net_outputs)
return outputs
