[Feature] Add deployment code and docs (#43)
parent
1d92bb5f10
commit
7ab3e65a12
18 changed files with 515 additions and 63 deletions
@ -0,0 +1,33 @@ |
||||
# Python部署 |
||||
|
||||
PaddleRS已经集成了基于Python的高性能预测(prediction)接口。在安装PaddleRS后,可参照如下代码示例执行预测。 |
||||
|
||||
## 部署模型导出 |
||||
|
||||
在服务端部署模型时需要首先将训练过程中保存的模型导出为部署格式,具体的导出步骤请参考文档[部署模型导出](/deploy/export/README.md)。 |
||||
|
||||
## 预测接口调用 |
||||
|
||||
* **基本使用** |
||||
|
||||
以下是一个调用PaddleRS Python预测接口的实例。首先构建`Predictor`对象,然后调用`Predictor`的`predict()`方法执行预测。 |
||||
|
||||
```python |
||||
import paddlers as pdrs |
||||
# 将导出模型所在目录传入Predictor的构造方法中 |
||||
predictor = pdrs.deploy.Predictor('./inference_model') |
||||
# img_file参数指定输入图像路径 |
||||
result = predictor.predict(img_file='test.jpg') |
||||
``` |
||||
|
||||
* **在预测过程中评估模型预测速度** |
||||
|
||||
加载模型后,对前几张图片的预测速度会较慢,这是因为程序刚启动时需要进行内存、显存初始化等步骤。通常,在处理20-30张图片后,模型的预测速度能够达到稳定值。基于这一观察,**如果需要评估模型的预测速度,可通过指定预热轮数`warmup_iters`对模型进行预热**。此外,**为获得更加精准的预测速度估计值,可指定重复`repeats`次预测后计算平均耗时**。 |
||||
|
||||
```python |
||||
import paddlers as pdrs |
||||
predictor = pdrs.deploy.Predictor('./inference_model') |
||||
result = predictor.predict(img_file='test.jpg', |
||||
warmup_iters=100, |
||||
repeats=100) |
||||
``` |
@ -0,0 +1,62 @@ |
||||
# 部署模型导出 |
||||
|
||||
## 目录 |
||||
|
||||
* [模型格式说明](#1) |
||||
* [训练模型格式](#11) |
||||
* [部署模型格式](#12) |
||||
* [部署模型导出](#2) |
||||
|
||||
## <h2 id="1">模型格式说明</h2> |
||||
|
||||
### <h3 id="11">训练模型格式</h3> |
||||
|
||||
使用PaddleRS训练模型,输出目录中主要包含四个文件: |
||||
|
||||
-`model.pdopt`,包含训练过程中使用到的优化器的状态参数; |
||||
-`model.pdparams`,包含模型的权重参数; |
||||
-`model.yml`,模型的配置文件(包括预处理参数、模型规格参数等); |
||||
-`eval_details.json`,包含验证阶段模型取得的指标。 |
||||
|
||||
需要注意的是,由于训练阶段使用模型的动态图版本,因此将上述格式的模型权重参数和配置文件直接用于部署往往效率不高。本项目建议将模型导出为专用的部署格式,在部署阶段使用静态图版本的模型以达到更高的推理效率。 |
||||
|
||||
### <h3 id="12">部署模型格式</h3> |
||||
|
||||
在服务端部署模型时,需要将训练过程中保存的模型导出为专用的格式。具体而言,在部署阶段,使用下述五个文件描述训练好的模型: |
||||
-`model.pdmodel`,记录模型的网络结构; |
||||
-`model.pdiparams`,包含模型权重参数; |
||||
-`model.pdiparams.info`,包含模型权重名称; |
||||
-`model.yml`,模型的配置文件(包括预处理参数、模型规格参数等); |
||||
-`pipeline.yml`,流程配置文件。 |
||||
|
||||
## <h2 id="2">部署模型导出</h2> |
||||
|
||||
使用如下指令导出部署格式的模型: |
||||
|
||||
```commandline |
||||
python deploy/export/export_model.py --model_dir=./output/deeplabv3p/best_model/ --save_dir=./inference_model/ |
||||
``` |
||||
|
||||
其中,`--model_dir`选项和`--save_dir`选项分别指定存储训练格式模型和部署格式模型的目录。例如,在上面的例子中,`./inference_model/`目录下将生成`model.pdmodel`、`model.pdiparams`、`model.pdiparams.info`、`model.yml`和`pipeline.yml`五个文件。 |
||||
|
||||
`deploy/export/export_model.py`脚本包含三个命令行选项: |
||||
|
||||
| 参数 | 说明 | |
||||
| ---- | ---- | |
||||
| --model_dir | 待导出的训练格式模型存储路径,例如`./output/deeplabv3p/best_model/`。 | |
||||
| --save_dir | 导出的部署格式模型存储路径,例如`./inference_model/`。 | |
||||
| --fixed_input_shape | 固定导出模型的输入张量形状。默认值为None,表示使用任务默认输入张量形状。 | |
||||
|
||||
当使用TensorRT执行模型推理时,需固定模型的输入张量形状。此时,可通过`--fixed_input_shape`选项来指定输入形状,具体有两种形式:`[w,h]`或者`[n,c,w,h]`。例如,指定`--fixed_input_shape`为`[224,224]`时,实际的输入张量形状可视为`[-1,3,224,224]`(-1表示可以为任意正整数,通道数默认为3);若想同时固定输入数据在batch维度的大小为1、通道数为4,则可将该选项设置为`[1,4,224,224]`。 |
||||
|
||||
完整命令示例: |
||||
|
||||
```commandline |
||||
python deploy/export_model.py --model_dir=./output/deeplabv3p/best_model/ --save_dir=./inference_model/ --fixed_input_shape=[224,224] |
||||
``` |
||||
|
||||
对于`--fixed_input_shape`选项,**请注意**: |
||||
-在推理阶段若需固定分类模型的输入形状,请保持其与训练阶段的输入形状一致。 |
||||
-对于检测模型中的YOLO/PPYOLO系列模型,请保证输入影像的`w`和`h`有相同取值、且均为32的倍数;指定`--fixed_input_shape`时,R-CNN模型的`w`和`h`也均需为32的倍数。 |
||||
-指定`[w,h]`时,请使用半角逗号(`,`)分隔`w`和`h`,二者之间不允许存在空格等其它字符。 |
||||
-将`w`和`h`设得越大,则模型在推理过程中的耗时和内存/显存占用越高。不过,如果`w`和`h`过小,则可能对模型的精度存在较大负面影响。 |
@ -0,0 +1,59 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import os |
||||
import argparse |
||||
from ast import literal_eval |
||||
|
||||
from paddlers.tasks import load_model |
||||
|
||||
|
||||
def get_parser(): |
||||
parser = argparse.ArgumentParser() |
||||
parser.add_argument('--model_dir', '-m', type=str, default=None, help='model directory path') |
||||
parser.add_argument('--save_dir', '-s', type=str, default=None, help='path to save inference model') |
||||
parser.add_argument('--fixed_input_shape', '-fs', type=str, default=None, |
||||
help="export inference model with fixed input shape: [w,h] or [n,c,w,h]") |
||||
return parser |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
parser = get_parser() |
||||
args = parser.parse_args() |
||||
|
||||
# Get input shape |
||||
fixed_input_shape = None |
||||
if args.fixed_input_shape is not None: |
||||
# Try to interpret the string as a list. |
||||
fixed_input_shape = literal_eval(args.fixed_input_shape) |
||||
# Check validaty |
||||
if not isinstance(fixed_input_shape, list): |
||||
raise ValueError("fixed_input_shape should be of None or list type.") |
||||
if len(fixed_input_shape) not in (2, 4): |
||||
raise ValueError("fixed_input_shape contains an incorrect number of elements.") |
||||
if fixed_input_shape[-1] <= 0 or fixed_input_shape[-2] <= 0: |
||||
raise ValueError("the input width and height must be positive integers.") |
||||
if len(fixed_input_shape)==4 and fixed_input_shape[1] <= 0: |
||||
raise ValueError("the number of input channels must be a positive integer.") |
||||
|
||||
# Set environment variables |
||||
os.environ['PADDLEX_EXPORT_STAGE'] = 'True' |
||||
os.environ['PADDLESEG_EXPORT_STAGE'] = 'True' |
||||
|
||||
# Load model from directory |
||||
model = load_model(args.model_dir) |
||||
|
||||
# Do dynamic-to-static cast |
||||
# XXX: Invoke a protected (single underscore) method outside of subclasses. |
||||
model._export_inference_model(args.save_dir, fixed_input_shape) |
@ -0,0 +1 @@ |
||||
from .predictor import Predictor |
@ -0,0 +1,283 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import os.path as osp |
||||
|
||||
import numpy as np |
||||
from paddle.inference import Config |
||||
from paddle.inference import create_predictor |
||||
from paddle.inference import PrecisionType |
||||
from paddlers.tasks import load_model |
||||
from paddlers.utils import logging, Timer |
||||
|
||||
|
||||
class Predictor(object): |
||||
def __init__(self, |
||||
model_dir, |
||||
use_gpu=False, |
||||
gpu_id=0, |
||||
cpu_thread_num=1, |
||||
use_mkl=True, |
||||
mkl_thread_num=4, |
||||
use_trt=False, |
||||
use_glog=False, |
||||
memory_optimize=True, |
||||
max_trt_batch_size=1, |
||||
trt_precision_mode='float32'): |
||||
""" |
||||
创建Paddle Predictor |
||||
|
||||
Args: |
||||
model_dir: 模型路径(必须是导出的部署或量化模型)。 |
||||
use_gpu: 是否使用GPU,默认为False。 |
||||
gpu_id: 使用GPU的ID,默认为0。 |
||||
cpu_thread_num:使用cpu进行预测时的线程数,默认为1。 |
||||
use_mkl: 是否使用mkldnn计算库,CPU情况下使用,默认为False。 |
||||
mkl_thread_num: mkldnn计算线程数,默认为4。 |
||||
use_trt: 是否使用TensorRT,默认为False。 |
||||
use_glog: 是否启用glog日志, 默认为False。 |
||||
memory_optimize: 是否启动内存优化,默认为True。 |
||||
max_trt_batch_size: 在使用TensorRT时配置的最大batch size,默认为1。 |
||||
trt_precision_mode:在使用TensorRT时采用的精度,可选值['float32', 'float16']。默认为'float32'。 |
||||
""" |
||||
|
||||
self.model_dir = model_dir |
||||
self._model = load_model(model_dir, with_net=False) |
||||
|
||||
if trt_precision_mode.lower() == 'float32': |
||||
trt_precision_mode = PrecisionType.Float32 |
||||
elif trt_precision_mode.lower() == 'float16': |
||||
trt_precision_mode = PrecisionType.Float16 |
||||
else: |
||||
logging.error( |
||||
"TensorRT precision mode {} is invalid. Supported modes are float32 and float16." |
||||
.format(trt_precision_mode), |
||||
exit=True) |
||||
|
||||
self.predictor = self.create_predictor( |
||||
use_gpu=use_gpu, |
||||
gpu_id=gpu_id, |
||||
cpu_thread_num=cpu_thread_num, |
||||
use_mkl=use_mkl, |
||||
mkl_thread_num=mkl_thread_num, |
||||
use_trt=use_trt, |
||||
use_glog=use_glog, |
||||
memory_optimize=memory_optimize, |
||||
max_trt_batch_size=max_trt_batch_size, |
||||
trt_precision_mode=trt_precision_mode) |
||||
self.timer = Timer() |
||||
|
||||
def create_predictor(self, |
||||
use_gpu=True, |
||||
gpu_id=0, |
||||
cpu_thread_num=1, |
||||
use_mkl=True, |
||||
mkl_thread_num=4, |
||||
use_trt=False, |
||||
use_glog=False, |
||||
memory_optimize=True, |
||||
max_trt_batch_size=1, |
||||
trt_precision_mode=PrecisionType.Float32): |
||||
config = Config( |
||||
osp.join(self.model_dir, 'model.pdmodel'), |
||||
osp.join(self.model_dir, 'model.pdiparams')) |
||||
|
||||
if use_gpu: |
||||
# 设置GPU初始显存(单位M)和Device ID |
||||
config.enable_use_gpu(200, gpu_id) |
||||
config.switch_ir_optim(True) |
||||
if use_trt: |
||||
if self._model.model_type == 'segmenter': |
||||
logging.warning( |
||||
"Semantic segmentation models do not support TensorRT acceleration, " |
||||
"TensorRT is forcibly disabled.") |
||||
elif 'RCNN' in self._model.__class__.__name__: |
||||
logging.warning( |
||||
"RCNN models do not support TensorRT acceleration, " |
||||
"TensorRT is forcibly disabled.") |
||||
else: |
||||
config.enable_tensorrt_engine( |
||||
workspace_size=1 << 10, |
||||
max_batch_size=max_trt_batch_size, |
||||
min_subgraph_size=3, |
||||
precision_mode=trt_precision_mode, |
||||
use_static=False, |
||||
use_calib_mode=False) |
||||
else: |
||||
config.disable_gpu() |
||||
config.set_cpu_math_library_num_threads(cpu_thread_num) |
||||
if use_mkl: |
||||
if self._model.__class__.__name__ == 'MaskRCNN': |
||||
logging.warning( |
||||
"MaskRCNN does not support MKL-DNN, MKL-DNN is forcibly disabled" |
||||
) |
||||
else: |
||||
try: |
||||
# cache 10 different shapes for mkldnn to avoid memory leak |
||||
config.set_mkldnn_cache_capacity(10) |
||||
config.enable_mkldnn() |
||||
config.set_cpu_math_library_num_threads(mkl_thread_num) |
||||
except Exception as e: |
||||
logging.warning( |
||||
"The current environment does not support MKL-DNN, MKL-DNN is disabled." |
||||
) |
||||
pass |
||||
|
||||
if not use_glog: |
||||
config.disable_glog_info() |
||||
if memory_optimize: |
||||
config.enable_memory_optim() |
||||
config.switch_use_feed_fetch_ops(False) |
||||
predictor = create_predictor(config) |
||||
return predictor |
||||
|
||||
def preprocess(self, images, transforms): |
||||
preprocessed_samples = self._model._preprocess( |
||||
images, transforms, to_tensor=False) |
||||
if self._model.model_type == 'classifier': |
||||
preprocessed_samples = {'image': preprocessed_samples[0]} |
||||
elif self._model.model_type == 'segmenter': |
||||
preprocessed_samples = { |
||||
'image': preprocessed_samples[0], |
||||
'ori_shape': preprocessed_samples[1] |
||||
} |
||||
elif self._model.model_type == 'detector': |
||||
pass |
||||
elif self._model.model_type == 'changedetector': |
||||
preprocessed_samples = { |
||||
'image': preprocessed_samples[0], |
||||
'image2': preprocessed_samples[1], |
||||
'ori_shape': preprocessed_samples[2] |
||||
} |
||||
else: |
||||
logging.error( |
||||
"Invalid model type {}".format(self._model.model_type), |
||||
exit=True) |
||||
return preprocessed_samples |
||||
|
||||
def postprocess(self, net_outputs, topk=1, ori_shape=None, transforms=None): |
||||
if self._model.model_type == 'classifier': |
||||
true_topk = min(self._model.num_classes, topk) |
||||
preds = self._model._postprocess(net_outputs[0], true_topk) |
||||
elif self._model.model_type in ('segmenter', 'changedetector'): |
||||
label_map, score_map = self._model._postprocess( |
||||
net_outputs, |
||||
batch_origin_shape=ori_shape, |
||||
transforms=transforms.transforms) |
||||
preds = [{ |
||||
'label_map': l, |
||||
'score_map': s |
||||
} for l, s in zip(label_map, score_map)] |
||||
elif self._model.model_type == 'detector': |
||||
net_outputs = { |
||||
k: v |
||||
for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs) |
||||
} |
||||
preds = self._model._postprocess(net_outputs) |
||||
else: |
||||
logging.error( |
||||
"Invalid model type {}.".format(self._model.model_type), |
||||
exit=True) |
||||
|
||||
return preds |
||||
|
||||
def raw_predict(self, inputs): |
||||
""" 接受预处理过后的数据进行预测 |
||||
Args: |
||||
inputs(dict): 预处理过后的数据 |
||||
""" |
||||
input_names = self.predictor.get_input_names() |
||||
for name in input_names: |
||||
input_tensor = self.predictor.get_input_handle(name) |
||||
input_tensor.copy_from_cpu(inputs[name]) |
||||
|
||||
self.predictor.run() |
||||
output_names = self.predictor.get_output_names() |
||||
net_outputs = list() |
||||
for name in output_names: |
||||
output_tensor = self.predictor.get_output_handle(name) |
||||
net_outputs.append(output_tensor.copy_to_cpu()) |
||||
|
||||
return net_outputs |
||||
|
||||
def _run(self, images, topk=1, transforms=None): |
||||
self.timer.preprocess_time_s.start() |
||||
preprocessed_input = self.preprocess(images, transforms) |
||||
self.timer.preprocess_time_s.end(iter_num=len(images)) |
||||
|
||||
self.timer.inference_time_s.start() |
||||
net_outputs = self.raw_predict(preprocessed_input) |
||||
self.timer.inference_time_s.end(iter_num=1) |
||||
|
||||
self.timer.postprocess_time_s.start() |
||||
results = self.postprocess( |
||||
net_outputs, |
||||
topk, |
||||
ori_shape=preprocessed_input.get('ori_shape', None), |
||||
transforms=transforms) |
||||
self.timer.postprocess_time_s.end(iter_num=len(images)) |
||||
|
||||
return results |
||||
|
||||
def predict(self, |
||||
img_file, |
||||
topk=1, |
||||
transforms=None, |
||||
warmup_iters=0, |
||||
repeats=1): |
||||
""" 图片预测 |
||||
Args: |
||||
img_file(List[np.ndarray or str], str or np.ndarray): |
||||
对于场景分类、图像复原、目标检测和语义分割任务来说,该参数可为单一图像路径,或是解码后的、排列格式为(H, W, C) |
||||
且具有float32类型的BGR图像(表示为numpy的ndarray形式),或者是一组图像路径或np.ndarray对象构成的列表;对于变化检测 |
||||
任务来说,该参数可以为图像路径二元组(分别表示前后两个时相影像路径),或是两幅图像组成的二元组,或者是上述两种二元组 |
||||
之一构成的列表。 |
||||
topk(int): 场景分类模型预测时使用,表示预测前topk的结果。默认值为1。 |
||||
transforms (paddlex.transforms): 数据预处理操作。默认值为None, 即使用`model.yml`中保存的数据预处理操作。 |
||||
warmup_iters (int): 预热轮数,用于评估模型推理以及前后处理速度。若大于1,会预先重复预测warmup_iters,而后才开始正式的预测及其速度评估。默认为0。 |
||||
repeats (int): 重复次数,用于评估模型推理以及前后处理速度。若大于1,会预测repeats次取时间平均值。默认值为1。 |
||||
""" |
||||
if repeats < 1: |
||||
logging.error("`repeats` must be greater than 1.", exit=True) |
||||
if transforms is None and not hasattr(self._model, 'test_transforms'): |
||||
raise Exception("Transforms need to be defined, now is None.") |
||||
if transforms is None: |
||||
transforms = self._model.test_transforms |
||||
if isinstance(img_file, tuple) and len(img_file) != 2: |
||||
raise ValueError( |
||||
f"A change detection model accepts exactly two input images, but there are {len(img_file)}." |
||||
) |
||||
if isinstance(img_file, (str, np.ndarray, tuple)): |
||||
images = [img_file] |
||||
else: |
||||
images = img_file |
||||
|
||||
for _ in range(warmup_iters): |
||||
self._run(images=images, topk=topk, transforms=transforms) |
||||
self.timer.reset() |
||||
|
||||
for _ in range(repeats): |
||||
results = self._run(images=images, topk=topk, transforms=transforms) |
||||
|
||||
self.timer.repeats = repeats |
||||
self.timer.img_num = len(images) |
||||
self.timer.info(average=True) |
||||
|
||||
if isinstance(img_file, (str, np.ndarray)): |
||||
results = results[0] |
||||
|
||||
return results |
||||
|
||||
def batch_predict(self, image_list, **params): |
||||
return self.predict(img_file=image_list, **params) |
Loading…
Reference in new issue