[Feature]: Evaluation with TensorRT backend (#5198)

* evaluate trt models

* update version of onnx

* update maskrcnn results

* add backend argument

* update fcos results

* update

* fix bug

* update doc
RunningLeon authored 4 years ago; committed by GitHub
parent 52c935d27b
commit 8d40aefe04
Changed files:
1. docs/tutorials/onnx2tensorrt.md (32 changed lines)
2. docs/tutorials/pytorch2onnx.md (73 changed lines)
3. mmdet/core/export/model_wrappers.py (170 changed lines)
4. mmdet/models/detectors/single_stage.py (18 changed lines)
5. mmdet/models/roi_heads/mask_heads/fcn_mask_head.py (6 changed lines)
6. tools/deployment/onnx2tensorrt.py (25 changed lines)
7. tools/deployment/pytorch2onnx.py (2 changed lines)
8. tools/deployment/test.py (22 changed lines)

docs/tutorials/onnx2tensorrt.md
@@ -6,6 +6,7 @@
- [How to convert models from ONNX to TensorRT](#how-to-convert-models-from-onnx-to-tensorrt)
- [Prerequisite](#prerequisite)
- [Usage](#usage)
- [How to evaluate the exported models](#how-to-evaluate-the-exported-models)
- [List of supported models convertible to TensorRT](#list-of-supported-models-convertible-to-tensorrt)
- [Reminders](#reminders)
- [FAQs](#faqs)
@@ -28,6 +29,7 @@ python tools/deployment/onnx2tensorrt.py \
--trt-file ${TRT_FILE} \
--input-img ${INPUT_IMAGE_PATH} \
--shape ${IMAGE_SHAPE} \
--max-shape ${MAX_IMAGE_SHAPE} \
--mean ${IMAGE_MEAN} \
--std ${IMAGE_STD} \
--dataset ${DATASET_NAME} \
@@ -42,6 +44,7 @@ Description of all arguments:
- `--trt-file`: The path of the output TensorRT engine file. If not specified, it will be set to `tmp.trt`.
- `--input-img` : The path of an input image for tracing and conversion. By default, it will be set to `demo/demo.jpg`.
- `--shape`: The height and width of model input. If not specified, it will be set to `400 600`.
- `--max-shape`: The maximum height and width of model input. If not specified, it will be set to the same as `--shape`. See the example after this list.
- `--mean` : Three mean values for the input image. If not specified, it will be set to `123.675 116.28 103.53`.
- `--std` : Three std values for the input image. If not specified, it will be set to `58.395 57.12 57.375`.
- `--dataset` : The dataset name for the input model. If not specified, it will be set to `coco`.
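For example, a dynamic-shape conversion that optimizes for 400×600 inputs but accepts anything up to 800×1333 might look like the following sketch. It assumes the config and ONNX files are passed as the script's positional arguments; all file paths are placeholders:

```bash
python tools/deployment/onnx2tensorrt.py \
    configs/retinanet/retinanet_r50_fpn_1x_coco.py \
    checkpoints/retinanet_r50_fpn_1x_coco.onnx \
    --trt-file retinanet.trt \
    --shape 400 600 \
    --max-shape 800 1333
```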
@@ -65,23 +68,32 @@ python tools/deployment/onnx2tensorrt.py \
--verify \
```
## How to evaluate the exported models

We provide a tool, `tools/deployment/test.py`, to evaluate TensorRT models.
Please refer to the following links for more information; a concrete example follows the list.

- [how-to-evaluate-the-exported-models](pytorch2onnx.md#how-to-evaluate-the-exported-models)
- [results-and-models](pytorch2onnx.md#results-and-models)
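For instance, evaluating a converted RetinaNet engine on COCO bounding-box mAP might look like this (a sketch; the config and engine paths are placeholders):

```bash
python tools/deployment/test.py \
    configs/retinanet/retinanet_r50_fpn_1x_coco.py \
    retinanet.trt \
    --backend tensorrt \
    --eval bbox
```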
## List of supported models convertible to TensorRT

The table below lists the models that are guaranteed to be convertible to TensorRT.

-| Model        | Config                                               | Status |
-| :----------: | :--------------------------------------------------: | :----: |
-| SSD          | `configs/ssd/ssd300_coco.py`                         | Y      |
-| FSAF         | `configs/fsaf/fsaf_r50_fpn_1x_coco.py`               | Y      |
-| FCOS         | `configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py`     | Y      |
-| YOLOv3       | `configs/yolo/yolov3_d53_mstrain-608_273e_coco.py`   | Y      |
-| RetinaNet    | `configs/retinanet/retinanet_r50_fpn_1x_coco.py`     | Y      |
-| Faster R-CNN | `configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py` | Y      |
-| Mask R-CNN   | `configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py`     | Y      |
+| Model        | Config                                               | Dynamic Shape | Batch Inference | Note |
+| :----------: | :--------------------------------------------------: | :-----------: | :-------------: | :--: |
+| SSD          | `configs/ssd/ssd300_coco.py`                         | Y             | Y               |      |
+| FSAF         | `configs/fsaf/fsaf_r50_fpn_1x_coco.py`               | Y             | Y               |      |
+| FCOS         | `configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py`     | Y             | Y               |      |
+| YOLOv3       | `configs/yolo/yolov3_d53_mstrain-608_273e_coco.py`   | Y             | Y               |      |
+| RetinaNet    | `configs/retinanet/retinanet_r50_fpn_1x_coco.py`     | Y             | Y               |      |
+| Faster R-CNN | `configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py` | Y             | Y               |      |
+| Mask R-CNN   | `configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py`     | Y             | Y               |      |

Notes:

-- *All models above are tested with Pytorch==1.6.0 and TensorRT-7.2.1.6.Ubuntu-16.04.x86_64-gnu.cuda-10.2.cudnn8.0*
+- *All models above are tested with Pytorch==1.6.0, onnx==1.7.0 and TensorRT-7.2.1.6.Ubuntu-16.04.x86_64-gnu.cuda-10.2.cudnn8.0*
## Reminders

docs/tutorials/pytorch2onnx.md
@@ -3,19 +3,19 @@
<!-- TOC -->
-- [Tutorial 8: Pytorch to ONNX (Experimental)](#tutorial-8-pytorch-to-onnx-experimental)
-- [How to convert models from Pytorch to ONNX](#how-to-convert-models-from-pytorch-to-onnx)
-- [Prerequisite](#prerequisite)
-- [Usage](#usage)
-- [Description of all arguments](#description-of-all-arguments)
-- [How to evaluate ONNX models with ONNX Runtime](#how-to-evaluate-onnx-models-with-onnx-runtime)
-- [Prerequisite](#prerequisite-1)
-- [Usage](#usage-1)
-- [Description of all arguments](#description-of-all-arguments-1)
-- [Results and Models](#results-and-models)
-- [List of supported models exportable to ONNX](#list-of-supported-models-exportable-to-onnx)
-- [The Parameters of Non-Maximum Suppression in ONNX Export](#the-parameters-of-non-maximum-suppression-in-onnx-export)
-- [Reminders](#reminders)
-- [FAQs](#faqs)
+- [How to convert models from Pytorch to ONNX](#how-to-convert-models-from-pytorch-to-onnx)
+- [Prerequisite](#prerequisite)
+- [Usage](#usage)
+- [Description of all arguments](#description-of-all-arguments)
+- [How to evaluate the exported models](#how-to-evaluate-the-exported-models)
+- [Prerequisite](#prerequisite-1)
+- [Usage](#usage-1)
+- [Description of all arguments](#description-of-all-arguments-1)
+- [Results and Models](#results-and-models)
+- [List of supported models exportable to ONNX](#list-of-supported-models-exportable-to-onnx)
+- [The Parameters of Non-Maximum Suppression in ONNX Export](#the-parameters-of-non-maximum-suppression-in-onnx-export)
+- [Reminders](#reminders)
+- [FAQs](#faqs)
<!-- TOC -->
@@ -85,14 +85,12 @@ python tools/deployment/pytorch2onnx.py \
--verify \
--dynamic-export \
--cfg-options \
model.test_cfg.nms_pre=200 \
model.test_cfg.max_per_img=200 \
-model.test_cfg.deploy_nms_pre=300 \
+model.test_cfg.deploy_nms_pre=-1 \
```
-## How to evaluate ONNX models with ONNX Runtime
+## How to evaluate the exported models

-We prepare a tool `tools/deplopyment/test.py` to evaluate ONNX models with ONNX Runtime backend.
+We provide a tool, `tools/deployment/test.py`, to evaluate exported models with the ONNXRuntime and TensorRT backends.
### Prerequisite
@@ -102,13 +100,16 @@ We provide a tool, `tools/deployment/test.py`, to evaluate exported models with ONNX
pip install onnx onnxruntime-gpu
```
- Install TensorRT by referring to [how-to-build-tensorrt-plugins-in-mmcv](https://mmcv.readthedocs.io/en/latest/tensorrt_plugin.html#how-to-build-tensorrt-plugins-in-mmcv) (optional; see the sketch below)
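If you plan to use the TensorRT backend, the linked guide amounts to roughly the following sketch. The build flags here are assumptions taken from the mmcv documentation, and the TensorRT SDK with its Python bindings is assumed to be installed already:

```bash
# build mmcv-full from source with its TensorRT plugins enabled
# (assumes the TensorRT SDK and python bindings are installed)
git clone https://github.com/open-mmlab/mmcv.git
cd mmcv
MMCV_WITH_OPS=1 MMCV_WITH_TRT=1 pip install -e .
```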
### Usage
```bash
python tools/deployment/test.py \
${CONFIG_FILE} \
-${ONNX_FILE} \
+${MODEL_FILE} \
--out ${OUTPUT_FILE} \
--backend ${BACKEND} \
--format-only ${FORMAT_ONLY} \
--eval ${EVALUATION_METRICS} \
--show-dir ${SHOW_DIRECTORY} \
@@ -120,8 +121,9 @@ python tools/deployment/test.py \
### Description of all arguments
- `config`: The path of a model config file.
-- `model`: The path of a ONNX model file.
+- `model`: The path of an input model file.
- `--out`: The path of the output result file in pickle format.
- `--backend`: The backend used to run the input model; should be `onnxruntime` or `tensorrt`. See the example after this list.
- `--format-only` : Format the output results without performing evaluation. It is useful when you want to format the results to a specific format and submit them to the test server. If not specified, it will be set to `False`.
- `--eval`: Evaluation metrics, which depend on the dataset, e.g., "bbox", "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC.
- `--show-dir`: Directory where painted images will be saved.
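As an example, an exported ONNX model could be evaluated with the ONNX Runtime backend like this (a sketch; file paths are placeholders):

```bash
python tools/deployment/test.py \
    configs/fsaf/fsaf_r50_fpn_1x_coco.py \
    fsaf.onnx \
    --backend onnxruntime \
    --eval bbox \
    --out results.pkl
```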
@@ -138,6 +140,7 @@ python tools/deployment/test.py \
<th align="center">Metric</th>
<th align="center">PyTorch</th>
<th align="center">ONNX Runtime</th>
<th align="center">TensorRT</th>
</tr >
<tr >
<td align="center">FCOS</td>
@@ -145,6 +148,7 @@ python tools/deployment/test.py \
<td align="center">Box AP</td>
<td align="center">36.6</td>
<td align="center">36.5</td>
<td align="center">36.3</td>
</tr>
<tr >
<td align="center">FSAF</td>
@@ -152,6 +156,7 @@ python tools/deployment/test.py \
<td align="center">Box AP</td>
<td align="center">36.0</td>
<td align="center">36.0</td>
<td align="center">35.9</td>
</tr>
<tr >
<td align="center">RetinaNet</td>
@@ -159,6 +164,7 @@ python tools/deployment/test.py \
<td align="center">Box AP</td>
<td align="center">36.5</td>
<td align="center">36.4</td>
<td align="center">36.3</td>
</tr>
<tr >
<td align="center" align="center" >SSD</td>
@ -166,6 +172,7 @@ python tools/deployment/test.py \
<td align="center" align="center">Box AP</td>
<td align="center" align="center">25.6</td>
<td align="center" align="center">25.6</td>
<td align="center" align="center">25.6</td>
</tr>
<tr >
<td align="center">YOLOv3</td>
@@ -173,6 +180,7 @@ python tools/deployment/test.py \
<td align="center">Box AP</td>
<td align="center">33.5</td>
<td align="center">33.5</td>
<td align="center">33.5</td>
</tr>
<tr >
<td align="center">Faster R-CNN</td>
@@ -180,6 +188,7 @@ python tools/deployment/test.py \
<td align="center">Box AP</td>
<td align="center">37.4</td>
<td align="center">37.4</td>
<td align="center">37.0</td>
</tr>
<tr >
<td align="center" rowspan="2">Mask R-CNN</td>
@@ -187,11 +196,13 @@ python tools/deployment/test.py \
<td align="center">Box AP</td>
<td align="center">38.2</td>
<td align="center">38.1</td>
<td align="center">37.7</td>
</tr>
<tr>
<td align="center">Mask AP</td>
<td align="center">34.7</td>
<td align="center">33.7</td>
<td align="center">33.3</td>
</tr>
<tr >
<td align="center">CornerNet</td>
@@ -206,25 +217,27 @@ Notes:
- All ONNX models are evaluated with dynamic shape on the COCO dataset, and images are preprocessed according to the original config file. Note that CornerNet is evaluated without test-time flip, since currently only single-scale evaluation is supported with ONNX Runtime.
-- Mask AP of Mask R-CNN drops by 1% for ONNXRuntime. The main reason is that the predicted masks are directly interpolated to original image in PyTorch, while they are at first interpolated to the preprocessed input image of the model and then to original image in ONNXRuntime.
+- Mask AP of Mask R-CNN drops by 1% for ONNXRuntime. The main reason is that the predicted masks are interpolated directly to the original image in PyTorch, while they are first interpolated to the preprocessed input image of the model and then to the original image in the other backends.
## List of supported models exportable to ONNX
The table below lists the models that are guaranteed to be exportable to ONNX and runnable in ONNX Runtime.
-| Model        | Config                                                              | Dynamic Shape | Batch Inference | Note |
-| :----------: | :-----------------------------------------------------------------: | :-----------: | :-------------: | :--: |
-| FCOS         | `configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py`            | Y             | Y               |      |
-| FSAF         | `configs/fsaf/fsaf_r50_fpn_1x_coco.py`                              | Y             | Y               |      |
-| RetinaNet    | `configs/retinanet/retinanet_r50_fpn_1x_coco.py`                    | Y             | Y               |      |
-| SSD          | `configs/ssd/ssd300_coco.py`                                        | Y             | Y               |      |
-| YOLOv3       | `configs/yolo/yolov3_d53_mstrain-608_273e_coco.py`                  | Y             | Y               |      |
-| Faster R-CNN | `configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py`                | Y             | Y               |      |
-| Mask R-CNN   | `configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py`                    | Y             | Y               |      |
+| Model        | Config                                                              | Dynamic Shape | Batch Inference | Note                                                                           |
+| :----------: | :-----------------------------------------------------------------: | :-----------: | :-------------: | :----------------------------------------------------------------------------: |
+| FCOS         | `configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py`            | Y             | Y               |                                                                                |
+| FSAF         | `configs/fsaf/fsaf_r50_fpn_1x_coco.py`                              | Y             | Y               |                                                                                |
+| RetinaNet    | `configs/retinanet/retinanet_r50_fpn_1x_coco.py`                    | Y             | Y               |                                                                                |
+| SSD          | `configs/ssd/ssd300_coco.py`                                        | Y             | Y               |                                                                                |
+| YOLOv3       | `configs/yolo/yolov3_d53_mstrain-608_273e_coco.py`                  | Y             | Y               |                                                                                |
+| Faster R-CNN | `configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py`                | Y             | Y               |                                                                                |
+| Mask R-CNN   | `configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py`                    | Y             | Y               |                                                                                |
+| CornerNet    | `configs/cornernet/cornernet_hourglass104_mstest_10x5_210e_coco.py` | Y             | N               | no flip, no batch inference, tested with torch==1.7.0 and onnxruntime==1.5.1. |
Notes:
- Minimum required version of MMCV is `1.3.5`
- *All models above are tested with Pytorch==1.6.0 and onnxruntime==1.5.1*, except for CornerNet. For details about the torch version required when exporting CornerNet to ONNX (which involves `mmcv::cummax`), please refer to the [Known Issues](https://github.com/open-mmlab/mmcv/blob/master/docs/onnxruntime_op.md#known-issues) in mmcv.

mmdet/core/export/model_wrappers.py
@@ -2,18 +2,103 @@ import os.path as osp
import warnings

import numpy as np
-import onnxruntime as ort
import torch

from mmdet.core import bbox2result
from mmdet.models import BaseDetector
-class ONNXRuntimeDetector(BaseDetector):
class DeployBaseDetector(BaseDetector):
    """Base class for detectors deployed on inference backends."""

    def __init__(self, class_names, device_id):
        super(DeployBaseDetector, self).__init__()
        self.CLASSES = class_names
        self.device_id = device_id

    def simple_test(self, img, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def aug_test(self, imgs, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def extract_feat(self, imgs):
        raise NotImplementedError('This method is not implemented.')

    def forward_train(self, imgs, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def val_step(self, data, optimizer):
        raise NotImplementedError('This method is not implemented.')

    def train_step(self, data, optimizer):
        raise NotImplementedError('This method is not implemented.')

    def aforward_test(self, *, img, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def async_simple_test(self, img, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def forward(self, img, img_metas, return_loss=True, **kwargs):
        # run the backend, then convert its raw outputs into the standard
        # mmdet bbox/segm result format
        outputs = self.forward_test(img, img_metas, **kwargs)
        batch_dets, batch_labels = outputs[:2]
        batch_masks = outputs[2] if len(outputs) == 3 else None
        batch_size = img[0].shape[0]
        img_metas = img_metas[0]
        results = []
        rescale = kwargs.get('rescale', True)
        for i in range(batch_size):
            dets, labels = batch_dets[i], batch_labels[i]
            if rescale:
                scale_factor = img_metas[i]['scale_factor']
                if isinstance(scale_factor, (list, tuple, np.ndarray)):
                    assert len(scale_factor) == 4
                    scale_factor = np.array(scale_factor)[None, :]  # [1, 4]
                dets[:, :4] /= scale_factor
            if 'border' in img_metas[i]:
                # offset pixel of the top-left corners between original image
                # and padded/enlarged image; 'border' is used when exporting
                # CornerNet and CentripetalNet to onnx
                x_off = img_metas[i]['border'][2]
                y_off = img_metas[i]['border'][0]
                dets[:, [0, 2]] -= x_off
                dets[:, [1, 3]] -= y_off
                dets[:, :4] *= (dets[:, :4] > 0).astype(dets.dtype)
            dets_results = bbox2result(dets, labels, len(self.CLASSES))
            if batch_masks is not None:
                masks = batch_masks[i]
                img_h, img_w = img_metas[i]['img_shape'][:2]
                ori_h, ori_w = img_metas[i]['ori_shape'][:2]
                masks = masks[:, :img_h, :img_w]
                if rescale:
                    # resize masks from the network input resolution back to
                    # the original image resolution
                    masks = masks.astype(np.float32)
                    masks = torch.from_numpy(masks)
                    masks = torch.nn.functional.interpolate(
                        masks.unsqueeze(0), size=(ori_h, ori_w))
                    masks = masks.squeeze(0).detach().numpy()
                if masks.dtype != np.bool:
                    masks = masks >= 0.5
                segms_results = [[] for _ in range(len(self.CLASSES))]
                for j in range(len(dets)):
                    segms_results[labels[j]].append(masks[j])
                results.append((dets_results, segms_results))
            else:
                results.append(dets_results)
        return results


class ONNXRuntimeDetector(DeployBaseDetector):
    """Wrapper for detector's inference with ONNXRuntime."""

    def __init__(self, onnx_file, class_names, device_id):
-        super(ONNXRuntimeDetector, self).__init__()
+        super(ONNXRuntimeDetector, self).__init__(class_names, device_id)
        import onnxruntime as ort

        # get the custom op path
        ort_custom_op_path = ''
        try:
@@ -37,25 +122,12 @@ class ONNXRuntimeDetector(BaseDetector):
        sess.set_providers(providers, options)

        self.sess = sess
-        self.CLASSES = class_names
-        self.device_id = device_id
        self.io_binding = sess.io_binding()
        self.output_names = [_.name for _ in sess.get_outputs()]
        self.is_cuda_available = is_cuda_available

-    def simple_test(self, img, img_metas, **kwargs):
-        raise NotImplementedError('This method is not implemented.')
-
-    def aug_test(self, imgs, img_metas, **kwargs):
-        raise NotImplementedError('This method is not implemented.')
-
-    def extract_feat(self, imgs):
-        raise NotImplementedError('This method is not implemented.')

    def forward_test(self, imgs, img_metas, **kwargs):
        input_data = imgs[0]
-        img_metas = img_metas[0]
-        batch_size = input_data.shape[0]
        # set io binding for inputs/outputs
        device_type = 'cuda' if self.is_cuda_available else 'cpu'
        if not self.is_cuda_available:
@@ -73,46 +145,28 @@ class ONNXRuntimeDetector(BaseDetector):
        # run session to get outputs
        self.sess.run_with_iobinding(self.io_binding)
        ort_outputs = self.io_binding.copy_outputs_to_cpu()
-        batch_dets, batch_labels = ort_outputs[:2]
-        batch_masks = ort_outputs[2] if len(ort_outputs) == 3 else None
+        return ort_outputs
-        results = []
-        for i in range(batch_size):
-            scale_factor = img_metas[i]['scale_factor']
-            dets, labels = batch_dets[i], batch_labels[i]
-            if isinstance(scale_factor, (list, tuple, np.ndarray)):
-                assert len(scale_factor) == 4
-                scale_factor = np.array(scale_factor)[None, :]  # [1,4]
-            dets[:, :4] /= scale_factor
-            if 'border' in img_metas[i]:
-                # offset pixel of the top-left corners between original image
-                # and padded/enlarged image, 'border' is used when exporting
-                # CornerNet and CentripetalNet to onnx
-                x_off = img_metas[i]['border'][2]
-                y_off = img_metas[i]['border'][0]
-                dets[:, [0, 2]] -= x_off
-                dets[:, [1, 3]] -= y_off
-                dets[:, :4] *= (dets[:, :4] > 0).astype(dets.dtype)
-            dets_results = bbox2result(dets, labels, len(self.CLASSES))
-            if batch_masks is not None:
-                masks = batch_masks[i]
-                img_h, img_w = img_metas[i]['img_shape'][:2]
-                ori_h, ori_w = img_metas[i]['ori_shape'][:2]
-                masks = masks[:, :img_h, :img_w]
-                mask_dtype = masks.dtype
-                masks = masks.astype(np.float32)
-                masks = torch.from_numpy(masks)
-                masks = torch.nn.functional.interpolate(
-                    masks.unsqueeze(0), size=(ori_h, ori_w))
-                masks = masks.squeeze(0).detach().numpy()
-                # convert mask to range(0,1)
-                if mask_dtype != np.bool:
-                    masks /= 255
-                masks = masks >= 0.5
-                segms_results = [[] for _ in range(len(self.CLASSES))]
-                for j in range(len(dets)):
-                    segms_results[labels[j]].append(masks[j])
-                results.append((dets_results, segms_results))
-            else:
-                results.append(dets_results)
-        return results


class TensorRTDetector(DeployBaseDetector):
    """Wrapper for detector's inference with TensorRT."""

    def __init__(self, engine_file, class_names, device_id, output_names):
        super(TensorRTDetector, self).__init__(class_names, device_id)
        try:
            from mmcv.tensorrt import TRTWraper
        except (ImportError, ModuleNotFoundError):
            raise RuntimeError(
                'Please install TensorRT: https://mmcv.readthedocs.io/en/latest/tensorrt_plugin.html#how-to-build-tensorrt-plugins-in-mmcv'  # noqa
            )
        self.output_names = output_names
        self.model = TRTWraper(engine_file, ['input'], output_names)

    def forward_test(self, imgs, img_metas, **kwargs):
        input_data = imgs[0]
        # execute the engine and move its outputs to CPU as numpy arrays
        with torch.cuda.device(self.device_id), torch.no_grad():
            outputs = self.model({'input': input_data})
        outputs = [outputs[name] for name in self.output_names]
        outputs = [out.detach().cpu().numpy() for out in outputs]
        return outputs

mmdet/models/detectors/single_stage.py
@@ -94,22 +94,8 @@ class SingleStageDetector(BaseDetector):
        """
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
-        # get origin input shape to support onnx dynamic shape
-        if torch.onnx.is_in_onnx_export():
-            # get shape as tensor
-            img_shape = torch._shape_as_tensor(img)[2:]
-            img_metas[0]['img_shape_for_onnx'] = img_shape
-            # get pad input shape to support onnx dynamic shape for exporting
-            # `CornerNet` and `CentripetalNet`, which 'pad_shape' is used
-            # for inference
-            img_metas[0]['pad_shape_for_onnx'] = img_shape
        bbox_list = self.bbox_head.get_bboxes(
            *outs, img_metas, rescale=rescale)
-        # skip post-processing when exporting to ONNX
-        if torch.onnx.is_in_onnx_export():
-            return bbox_list
        bbox_results = [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
            for det_bboxes, det_labels in bbox_list
@@ -159,6 +145,10 @@ class SingleStageDetector(BaseDetector):
        # get shape as tensor
        img_shape = torch._shape_as_tensor(img)[2:]
        img_metas[0]['img_shape_for_onnx'] = img_shape
        # get pad input shape to support onnx dynamic shape for exporting
        # `CornerNet` and `CentripetalNet`, which use 'pad_shape' for
        # inference
        img_metas[0]['pad_shape_for_onnx'] = img_shape
        # TODO: move all onnx related code in bbox_head to onnx_export function
        det_bboxes, det_labels = self.bbox_head.get_bboxes(*outs, img_metas)

mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
@@ -1,4 +1,3 @@
-import os
from warnings import warn

import numpy as np
@@ -337,11 +336,6 @@ class FCNMaskHead(BaseModule):
            mask_pred, bboxes, img_h, img_w, skip_empty=False)
        if threshold >= 0:
            masks = (masks >= threshold).to(dtype=torch.bool)
-        else:
-            # TensorRT backend does not have data type of uint8
-            is_trt_backend = os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT'
-            target_dtype = torch.int32 if is_trt_backend else torch.uint8
-            masks = (masks * 255).to(dtype=target_dtype)
        return masks

tools/deployment/onnx2tensorrt.py
@@ -31,8 +31,9 @@ def onnx2tensorrt(onnx_file,
    import tensorrt as trt
    onnx_model = onnx.load(onnx_file)
    input_shape = input_config['input_shape']
    max_shape = input_config['max_shape']
    # create trt engine and wrapper; the profile lists the
    # [min, opt, max] shapes of the dynamic-shape input
-    opt_shape_dict = {'input': [input_shape, input_shape, input_shape]}
+    opt_shape_dict = {'input': [input_shape, input_shape, max_shape]}
    max_workspace_size = get_GiB(workspace_size)
    trt_engine = onnx2trt(
        onnx_model,
@@ -84,6 +85,9 @@ def onnx2tensorrt(onnx_file,
                  output shapes: {trt_shapes}')
    trt_masks = trt_outputs[2] if with_mask else None
    if trt_masks is not None and trt_masks.dtype != np.bool:
        # binarize float masks from both backends before comparison
        trt_masks = trt_masks >= 0.5
        ort_masks = ort_masks >= 0.5
    # Show detection outputs
    if show:
        CLASSES = get_classes(dataset)
@@ -148,6 +152,12 @@ def parse_args():
        nargs='+',
        default=[400, 600],
        help='Input size of the model')
    parser.add_argument(
        '--max-shape',
        type=int,
        nargs='+',
        default=None,
        help='Maximum input size of the model in TensorRT')
    parser.add_argument(
        '--mean',
        type=float,
@@ -184,6 +194,16 @@ if __name__ == '__main__':
    else:
        raise ValueError('invalid input shape')

    # fall back to a static-shape engine when --max-shape is not given
    if not args.max_shape:
        max_shape = input_shape
    else:
        if len(args.max_shape) == 1:
            max_shape = (1, 3, args.max_shape[0], args.max_shape[0])
        elif len(args.max_shape) == 2:
            max_shape = (1, 3) + tuple(args.max_shape)
        else:
            raise ValueError('invalid input max_shape')

    assert len(args.mean) == 3
    assert len(args.std) == 3
@@ -191,7 +211,8 @@ if __name__ == '__main__':
    input_config = {
        'input_shape': input_shape,
        'input_path': args.input_img,
-        'normalize_cfg': normalize_cfg
+        'normalize_cfg': normalize_cfg,
+        'max_shape': max_shape
    }

    # Create TensorRT engine

tools/deployment/pytorch2onnx.py
@@ -150,6 +150,8 @@ def pytorch2onnx(config_path,
        onnx_results = bbox2result(ort_dets, ort_labels, num_classes)
        if model.with_mask:
            segm_results = onnx_outputs[2]
            # scale float masks to uint8 images for visualization
            if segm_results.dtype != np.bool:
                segm_results = (segm_results * 255).astype(np.uint8)
            cls_segms = [[] for _ in range(num_classes)]
            for i in range(ort_dets.shape[0]):
                cls_segms[ort_labels[i]].append(segm_results[i])

tools/deployment/test.py
@@ -5,7 +5,6 @@ from mmcv import Config, DictAction
from mmcv.parallel import MMDataParallel

from mmdet.apis import single_gpu_test
-from mmdet.core.export.model_wrappers import ONNXRuntimeDetector
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
@@ -22,6 +21,11 @@ def parse_args():
        help='Format the output results without performing evaluation. It is '
        'useful when you want to format the results to a specific format and '
        'submit them to the test server')
    parser.add_argument(
        '--backend',
        required=True,
        choices=['onnxruntime', 'tensorrt'],
        help='Backend for the input model to run.')
    parser.add_argument(
        '--eval',
        type=str,
@@ -103,8 +107,20 @@ def main():
        dist=False,
        shuffle=False)

-    model = ONNXRuntimeDetector(
-        args.model, class_names=dataset.CLASSES, device_id=0)
    if args.backend == 'onnxruntime':
        from mmdet.core.export.model_wrappers import ONNXRuntimeDetector
        model = ONNXRuntimeDetector(
            args.model, class_names=dataset.CLASSES, device_id=0)
    elif args.backend == 'tensorrt':
        from mmdet.core.export.model_wrappers import TensorRTDetector
        # configs that evaluate both bbox and segm metrics (e.g. Mask R-CNN)
        # expect a third 'masks' output from the engine
        output_names = ['dets', 'labels']
        if len(cfg.evaluation['metric']) == 2:
            output_names.append('masks')
        model = TensorRTDetector(
            args.model,
            class_names=dataset.CLASSES,
            device_id=0,
            output_names=output_names)

    model = MMDataParallel(model, device_ids=[0])
    outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
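For example, with a Mask R-CNN config, `cfg.evaluation['metric']` is `['bbox', 'segm']`, so the branch above requests the extra `masks` output from the engine. A sketch of the corresponding invocation (the engine path is a placeholder):

```bash
python tools/deployment/test.py \
    configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py \
    mask_rcnn.trt \
    --backend tensorrt \
    --eval bbox segm
```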
