[Enhance]: Support exporting some detection models to ONNX (#4087)

* Modify models to be exportable to ONNX

* Support new_tensor and new_full when exporting to ONNX

* Replace grid_sample with resize when exporting to ONNX

* Add verification of mask results

* Verify all outputs between PyTorch and ONNX

* Avoid a runtime split error for Mask R-CNN

* Remove an unnecessary change in bbox2roi
commit 63772c5cb8 (parent 0293aedb67), authored by RunningLeon, committed via GitHub
13 changed files (changed lines in parentheses):

1. mmdet/apis/inference.py (14)
2. mmdet/core/bbox/coder/tblr_bbox_coder.py (2)
3. mmdet/core/bbox/transforms.py (4)
4. mmdet/core/export/pytorch2onnx.py (7)
5. mmdet/core/post_processing/bbox_nms.py (30)
6. mmdet/models/dense_heads/yolo_head.py (3)
7. mmdet/models/detectors/base.py (5)
8. mmdet/models/roi_heads/cascade_roi_head.py (3)
9. mmdet/models/roi_heads/mask_heads/fcn_mask_head.py (15)
10. mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py (27)
11. mmdet/models/roi_heads/standard_roi_head.py (8)
12. mmdet/models/roi_heads/test_mixins.py (38)
13. tools/pytorch2onnx.py (89)

mmdet/apis/inference.py
@@ -154,7 +154,13 @@ async def async_inference_detector(model, img):
     return result


-def show_result_pyplot(model, img, result, score_thr=0.3, fig_size=(15, 10)):
+def show_result_pyplot(model,
+                       img,
+                       result,
+                       score_thr=0.3,
+                       fig_size=(15, 10),
+                       title='result',
+                       block=True):
     """Visualize the detection results on the image.

     Args:
@@ -164,10 +170,14 @@ def show_result_pyplot(model, img, result, score_thr=0.3, fig_size=(15, 10)):
             (bbox, segm) or just bbox.
         score_thr (float): The threshold to visualize the bboxes and masks.
         fig_size (tuple): Figure size of the pyplot figure.
+        title (str): Title of the pyplot figure.
+        block (bool): Whether to block GUI.
     """
     if hasattr(model, 'module'):
         model = model.module
     img = model.show_result(img, result, score_thr=score_thr, show=False)
     plt.figure(figsize=fig_size)
     plt.imshow(mmcv.bgr2rgb(img))
-    plt.show()
+    plt.title(title)
+    plt.tight_layout()
+    plt.show(block=block)
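
The new `title` and `block` arguments let a caller keep a first figure open while drawing a second one, which is how the verification step at the end of this diff compares predictions. A minimal usage sketch (`model`, `img`, and the two result objects are assumed to come from earlier inference calls):

from mmdet.apis import show_result_pyplot

# Keep the PyTorch window open (block=False) while the ONNX window draws.
show_result_pyplot(model, img, pytorch_results, title='Pytorch', block=False)
show_result_pyplot(model, img, onnx_results, title='ONNX')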

mmdet/core/bbox/coder/tblr_bbox_coder.py
@@ -158,7 +158,7 @@ def tblr2bboxes(priors,
     w, h = torch.split(wh, 1, dim=1)
     loc_decode[:, :2] *= h  # tb
     loc_decode[:, 2:] *= w  # lr
-    top, bottom, left, right = loc_decode.split(1, dim=1)
+    top, bottom, left, right = loc_decode.split((1, 1, 1, 1), dim=1)
     xmin = prior_centers[:, 0].unsqueeze(1) - left
     xmax = prior_centers[:, 0].unsqueeze(1) + right
     ymin = prior_centers[:, 1].unsqueeze(1) - top
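
With `split(1, dim=1)` the number of outputs depends on the runtime width of dim 1, which the ONNX tracer cannot pin down; passing explicit section sizes fixes the arity of the exported Split node. A quick sketch of the equivalence on toy data:

import torch

x = torch.arange(12.).reshape(3, 4)
# Chunk-size form: "split into chunks of size 1"; the output count is
# only known once dim 1's size is known at runtime.
a = x.split(1, dim=1)
# Section-sizes form: four outputs are declared up front, so the traced
# graph contains a Split node with a fixed number of outputs.
b = x.split((1, 1, 1, 1), dim=1)
assert all(torch.equal(u, v) for u, v in zip(a, b))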

mmdet/core/bbox/transforms.py
@@ -111,8 +111,8 @@ def bbox2result(bboxes, labels, num_classes):
         return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)]
     else:
         if isinstance(bboxes, torch.Tensor):
-            bboxes = bboxes.cpu().numpy()
-            labels = labels.cpu().numpy()
+            bboxes = bboxes.detach().cpu().numpy()
+            labels = labels.detach().cpu().numpy()
         return [bboxes[labels == i, :] for i in range(num_classes)]
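
The `detach()` is needed because the export tooling feeds an input created with `requires_grad_(True)` (see `preprocess_example_input` below), and NumPy conversion refuses tensors that still carry autograd history. A minimal reproduction:

import torch

t = torch.zeros(2, 5, requires_grad=True)
# t.cpu().numpy() raises "Can't call numpy() on Tensor that requires
# grad"; detach() severs the autograd link so the conversion succeeds.
arr = t.detach().cpu().numpy()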

mmdet/core/export/pytorch2onnx.py
@@ -119,12 +119,14 @@ def preprocess_example_input(input_config):
     input_path = input_config['input_path']
     input_shape = input_config['input_shape']
     one_img = mmcv.imread(input_path)
+    one_img = mmcv.imresize(one_img, input_shape[2:][::-1])
+    show_img = one_img.copy()
     if 'normalize_cfg' in input_config.keys():
         normalize_cfg = input_config['normalize_cfg']
         mean = np.array(normalize_cfg['mean'], dtype=np.float32)
         std = np.array(normalize_cfg['std'], dtype=np.float32)
         one_img = mmcv.imnormalize(one_img, mean, std)
-    one_img = mmcv.imresize(one_img, input_shape[2:][::-1]).transpose(2, 0, 1)
+    one_img = one_img.transpose(2, 0, 1)
     one_img = torch.from_numpy(one_img).unsqueeze(0).float().requires_grad_(
         True)
     (_, C, H, W) = input_shape
@@ -134,7 +136,8 @@ def preprocess_example_input(input_config):
         'pad_shape': (H, W, C),
         'filename': '<demo>.png',
         'scale_factor': 1.0,
-        'flip': False
+        'flip': False,
+        'show_img': show_img,
     }
     return one_img, one_meta
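
Resizing now happens before normalization so `show_img` preserves an un-normalized copy of the resized image for visualization. A hypothetical `input_config` exercising the keys read above (the image path and normalization numbers are placeholders, not values taken from this diff):

input_config = {
    'input_path': 'demo/demo.jpg',
    'input_shape': (1, 3, 800, 1216),
    'normalize_cfg': {
        'mean': [123.675, 116.28, 103.53],
        'std': [58.395, 57.12, 57.375],
    },
}
one_img, one_meta = preprocess_example_input(input_config)
# one_img: (1, 3, 800, 1216) float tensor with requires_grad=True;
# one_meta['show_img'] is the resized BGR image before normalization.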

mmdet/core/post_processing/bbox_nms.py
@@ -35,33 +35,29 @@ def multiclass_nms(multi_bboxes,
     else:
         bboxes = multi_bboxes[:, None].expand(
             multi_scores.size(0), num_classes, 4)
-    scores = multi_scores[:, :-1]
-
-    # filter out boxes with low scores
-    valid_mask = scores > score_thr
-
-    # We use masked_select for ONNX exporting purpose,
-    # which is equivalent to bboxes = bboxes[valid_mask]
-    # (TODO): as ONNX does not support repeat now,
-    # we have to use this ugly code
-    bboxes = torch.masked_select(
-        bboxes,
-        torch.stack((valid_mask, valid_mask, valid_mask, valid_mask),
-                    -1)).view(-1, 4)
+    scores = multi_scores[:, :-1]
     if score_factors is not None:
         scores = scores * score_factors[:, None]
-    scores = torch.masked_select(scores, valid_mask)
-    labels = valid_mask.nonzero(as_tuple=False)[:, 1]
-    if bboxes.numel() == 0:
-        bboxes = multi_bboxes.new_zeros((0, 5))
-        labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)
+
+    labels = torch.arange(num_classes, dtype=torch.long)
+    labels = labels.view(1, -1).expand_as(scores)
+
+    bboxes = bboxes.reshape(-1, 4)
+    scores = scores.reshape(-1)
+    labels = labels.reshape(-1)
+
+    # remove low scoring boxes
+    valid_mask = scores > score_thr
+    inds = valid_mask.nonzero(as_tuple=False).squeeze(1)
+    bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds]
+    if inds.numel() == 0:
+        if torch.onnx.is_in_onnx_export():
+            raise RuntimeError('[ONNX Error] Can not record NMS '
+                               'as it has not been executed this time')
         return bboxes, labels

+    # TODO: add size check before feed into batched_nms
     dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)

     if max_num > 0:
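
The rewrite replaces the per-class `masked_select` gymnastics with a dense enumeration of (box, class) pairs followed by one integer-index gather, which lowers to plain Reshape/NonZero/Gather ops in ONNX. A toy sketch of the labeling step:

import torch

num_boxes, num_classes = 4, 3
scores = torch.rand(num_boxes, num_classes)
labels = torch.arange(num_classes, dtype=torch.long)
labels = labels.view(1, -1).expand_as(scores)  # class id per (box, class)
scores, labels = scores.reshape(-1), labels.reshape(-1)
keep = (scores > 0.5).nonzero(as_tuple=False).squeeze(1)
scores, labels = scores[keep], labels[keep]  # one gather, no masked_select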

mmdet/models/dense_heads/yolo_head.py
@@ -277,7 +277,8 @@ class YOLOV3Head(BaseDenseHead, BBoxTestMixin):
             # Get top-k prediction
             nms_pre = cfg.get('nms_pre', -1)
-            if 0 < nms_pre < conf_pred.size(0):
+            if 0 < nms_pre < conf_pred.size(0) and (
+                    not torch.onnx.is_in_onnx_export()):
                 _, topk_inds = conf_pred.topk(nms_pre)
                 bbox_pred = bbox_pred[topk_inds, :]
                 cls_pred = cls_pred[topk_inds, :]
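
This is the gating pattern used throughout the PR: `torch.onnx.is_in_onnx_export()` is true only while `torch.onnx.export` is tracing, so data-dependent shortcuts such as this top-k filter run in eager mode but stay out of the exported graph. A self-contained sketch of the mechanism (the module and names are illustrative):

import io
import torch

class Branchy(torch.nn.Module):
    def forward(self, x):
        # Recorded into the ONNX graph: the tensor-only branch.
        if torch.onnx.is_in_onnx_export():
            return x * 2
        # Taken in eager mode: may use Python-side post-processing.
        return (x * 2).tolist()

m = Branchy()
assert m(torch.ones(2)) == [2.0, 2.0]                # eager path
torch.onnx.export(m, torch.ones(2), io.BytesIO())    # traces the x * 2 path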

mmdet/models/detectors/base.py
@@ -315,7 +315,10 @@ class BaseDetector(nn.Module, metaclass=ABCMeta):
             for i in inds:
                 i = int(i)
                 color_mask = color_masks[labels[i]]
-                mask = segms[i].astype(bool)
+                sg = segms[i]
+                if isinstance(sg, torch.Tensor):
+                    sg = sg.detach().cpu().numpy()
+                mask = sg.astype(bool)
                 img[mask] = img[mask] * 0.5 + color_mask * 0.5
         # if out_file specified, do not show image in window
         if out_file is not None:

mmdet/models/roi_heads/cascade_roi_head.py
@@ -348,6 +348,9 @@ class CascadeRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
                 cfg=rcnn_test_cfg)
             det_bboxes.append(det_bbox)
             det_labels.append(det_label)
+
+        if torch.onnx.is_in_onnx_export():
+            return det_bboxes, det_labels
         bbox_results = [
             bbox2result(det_bboxes[i], det_labels[i],
                         self.bbox_head[-1].num_classes)

mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
@@ -195,6 +195,16 @@ class FCNMaskHead(nn.Module):
             scale_factor = bboxes.new_tensor(scale_factor)
             bboxes = bboxes / scale_factor

+        if torch.onnx.is_in_onnx_export():
+            # TODO: Remove after F.grid_sample is supported.
+            from torchvision.models.detection.roi_heads \
+                import paste_masks_in_image
+            masks = paste_masks_in_image(mask_pred, bboxes, ori_shape[:2])
+            thr = rcnn_test_cfg.get('mask_thr_binary', 0)
+            if thr > 0:
+                masks = masks >= thr
+            return masks
+
         N = len(mask_pred)
         # The actual implementation split the input into chunks,
         # and paste them chunk by chunk.
@@ -240,7 +250,7 @@ class FCNMaskHead(nn.Module):
             im_mask[(inds, ) + spatial_inds] = masks_chunk

         for i in range(N):
-            cls_segms[labels[i]].append(im_mask[i].cpu().numpy())
+            cls_segms[labels[i]].append(im_mask[i].detach().cpu().numpy())
         return cls_segms
@@ -306,6 +316,9 @@ def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True):
     gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
     grid = torch.stack([gx, gy], dim=3)

+    if torch.onnx.is_in_onnx_export():
+        raise RuntimeError(
+            'Exporting F.grid_sample from Pytorch to ONNX is not supported.')
     img_masks = F.grid_sample(
         masks.to(dtype=torch.float32), grid, align_corners=False)
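
The export path leans on torchvision's `paste_masks_in_image`, which resizes each per-RoI mask into its box without `F.grid_sample`. A standalone sketch with illustrative shapes (28x28 mask probabilities pasted into a 160x160 image):

import torch
from torchvision.models.detection.roi_heads import paste_masks_in_image

mask_pred = torch.rand(2, 1, 28, 28)  # per-RoI mask probabilities
bboxes = torch.tensor([[10., 20., 50., 80.],
                       [30., 30., 90., 120.]])  # x1, y1, x2, y2
masks = paste_masks_in_image(mask_pred, bboxes, (160, 160))  # (2, 1, 160, 160)
binary = masks >= 0.5  # mirrors the mask_thr_binary thresholding above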

mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
@@ -55,8 +55,16 @@ class SingleRoIExtractor(BaseRoIExtractor):
         """Forward function."""
         out_size = self.roi_layers[0].output_size
         num_levels = len(feats)
-        roi_feats = feats[0].new_zeros(
-            rois.size(0), self.out_channels, *out_size)
+        if torch.onnx.is_in_onnx_export():
+            # Work around to export mask-rcnn to onnx
+            roi_feats = rois[:, :1].clone().detach()
+            roi_feats = roi_feats.expand(
+                -1, self.out_channels * out_size[0] * out_size[1])
+            roi_feats = roi_feats.reshape(-1, self.out_channels, *out_size)
+            roi_feats = roi_feats * 0
+        else:
+            roi_feats = feats[0].new_zeros(
+                rois.size(0), self.out_channels, *out_size)
         # TODO: remove this when parrots supports
         if torch.__version__ == 'parrots':
             roi_feats.requires_grad = True
@@ -69,10 +77,19 @@ class SingleRoIExtractor(BaseRoIExtractor):
         target_lvls = self.map_roi_levels(rois, num_levels)
         if roi_scale_factor is not None:
             rois = self.roi_rescale(rois, roi_scale_factor)
+
         for i in range(num_levels):
-            inds = target_lvls == i
-            if inds.any():
-                rois_ = rois[inds, :]
+            mask = target_lvls == i
+            inds = mask.nonzero(as_tuple=False).squeeze(1)
+            # TODO: make it nicer when exporting to onnx
+            if torch.onnx.is_in_onnx_export():
+                # To keep all roi_align nodes exported to onnx
+                rois_ = rois[inds]
+                roi_feats_t = self.roi_layers[i](feats[i], rois_)
+                roi_feats[inds] = roi_feats_t
+                continue
+            if inds.numel() > 0:
+                rois_ = rois[inds]
                 roi_feats_t = self.roi_layers[i](feats[i], rois_)
                 roi_feats[inds] = roi_feats_t
             else:
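
The placeholder workaround matters because `feats[0].new_zeros(rois.size(0), ...)` bakes the RoI count into the traced graph as a constant, while deriving the tensor from `rois` itself keeps that dimension symbolic. A small sketch of the equivalence in eager mode:

import torch

rois = torch.rand(5, 5)  # (num_rois, 5); num_rois varies per image
out_channels, out_size = 4, (7, 7)

# new_zeros: tracing records the literal 5 as a fixed shape.
static = rois.new_zeros(rois.size(0), out_channels, *out_size)

# Derived from rois, so the RoI dimension stays dynamic in the graph.
dynamic = rois[:, :1].clone().detach().expand(
    -1, out_channels * out_size[0] * out_size[1]).reshape(
        -1, out_channels, *out_size) * 0
assert torch.equal(static, dynamic)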

mmdet/models/roi_heads/standard_roi_head.py
@@ -246,6 +246,14 @@ class StandardRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
         det_bboxes, det_labels = self.simple_test_bboxes(
             x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
+
+        if torch.onnx.is_in_onnx_export():
+            if self.with_mask:
+                segm_results = self.simple_test_mask(
+                    x, img_metas, det_bboxes, det_labels, rescale=rescale)
+                return det_bboxes, det_labels, segm_results
+            else:
+                return det_bboxes, det_labels
         bbox_results = [
             bbox2result(det_bboxes[i], det_labels[i],
                         self.bbox_head.num_classes)

mmdet/models/roi_heads/test_mixins.py
@@ -194,17 +194,33 @@ class MaskTestMixin(object):
                 torch.from_numpy(scale_factor).to(det_bboxes[0].device)
                 for scale_factor in scale_factors
             ]
-            _bboxes = [
-                det_bboxes[i][:, :4] *
-                scale_factors[i] if rescale else det_bboxes[i][:, :4]
-                for i in range(len(det_bboxes))
-            ]
-            mask_rois = bbox2roi(_bboxes)
-            mask_results = self._mask_forward(x, mask_rois)
-            mask_pred = mask_results['mask_pred']
-            # split batch mask prediction back to each image
-            num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes]
-            mask_preds = mask_pred.split(num_mask_roi_per_img, 0)
+            if torch.onnx.is_in_onnx_export():
+                # avoid mask_pred.split with static number of prediction
+                mask_preds = []
+                _bboxes = []
+                for i, boxes in enumerate(det_bboxes):
+                    boxes = boxes[:, :4]
+                    if rescale:
+                        boxes *= scale_factors[i]
+                    _bboxes.append(boxes)
+                    img_inds = boxes[:, :1].clone() * 0 + i
+                    mask_rois = torch.cat([img_inds, boxes], dim=-1)
+                    mask_result = self._mask_forward(x, mask_rois)
+                    mask_preds.append(mask_result['mask_pred'])
+            else:
+                _bboxes = [
+                    det_bboxes[i][:, :4] *
+                    scale_factors[i] if rescale else det_bboxes[i][:, :4]
+                    for i in range(len(det_bboxes))
+                ]
+                mask_rois = bbox2roi(_bboxes)
+                mask_results = self._mask_forward(x, mask_rois)
+                mask_pred = mask_results['mask_pred']
+                # split batch mask prediction back to each image
+                num_mask_roi_per_img = [
+                    det_bbox.shape[0] for det_bbox in det_bboxes
+                ]
+                mask_preds = mask_pred.split(num_mask_roi_per_img, 0)

             # apply mask post-processing to each image individually
             segm_results = []
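
Instead of concatenating boxes with `bbox2roi` and splitting the batched prediction back out (the `split` that fails at export time), the ONNX branch runs the mask head once per image and builds the 5-column RoI format by hand. A sketch of the RoI construction for a single image:

import torch

boxes = torch.rand(3, 4) * 100  # (n, 4) detections for image i
i = 1
# Prepend the batch-index column that RoIAlign expects, without any
# cross-image concat/split bookkeeping.
img_inds = boxes[:, :1].clone() * 0 + i
mask_rois = torch.cat([img_inds, boxes], dim=-1)  # (n, 5): [i, x1, y1, x2, y2]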

tools/pytorch2onnx.py
@@ -18,7 +18,9 @@ def pytorch2onnx(config_path,
                  show=False,
                  output_file='tmp.onnx',
                  verify=False,
-                 normalize_cfg=None):
+                 normalize_cfg=None,
+                 dataset='coco',
+                 test_img=None):

     input_config = {
         'input_shape': input_shape,
@@ -29,30 +31,44 @@
     # prepare original model and meta for verifying the onnx model
     orig_model = build_model_from_cfg(config_path, checkpoint_path)
     one_img, one_meta = preprocess_example_input(input_config)
     model, tensor_data = generate_inputs_and_wrap_model(
         config_path, checkpoint_path, input_config)
+    output_names = ['boxes']
+    if model.with_bbox:
+        output_names.append('labels')
+    if model.with_mask:
+        output_names.append('masks')
     torch.onnx.export(
         model,
         tensor_data,
         output_file,
         input_names=['input'],
+        output_names=output_names,
         export_params=True,
         keep_initializers_as_inputs=True,
         do_constant_folding=True,
         verbose=show,
         opset_version=opset_version)
     model.forward = orig_model.forward
     print(f'Successfully exported ONNX model: {output_file}')
     if verify:
+        from mmdet.core import get_classes
+        from mmdet.apis import show_result_pyplot
+        model.CLASSES = get_classes(dataset)
+        num_classes = len(model.CLASSES)
         # check by onnx
         onnx_model = onnx.load(output_file)
         onnx.checker.check_model(onnx_model)
+        if test_img is not None:
+            input_config['input_path'] = test_img
+            one_img, one_meta = preprocess_example_input(input_config)
+            tensor_data = [one_img]
         # check the numerical value
         # get pytorch output
-        pytorch_result = model(tensor_data, [[one_meta]], return_loss=False)
+        pytorch_results = model(tensor_data, [[one_meta]], return_loss=False)
+        pytorch_results = pytorch_results[0]
         # get onnx output
         input_all = [node.name for node in onnx_model.graph.input]
         input_initializer = [
@@ -62,14 +78,52 @@ def pytorch2onnx(config_path,
         assert (len(net_feed_input) == 1)
         sess = rt.InferenceSession(output_file)
         from mmdet.core import bbox2result
-        det_bboxes, det_labels = sess.run(
-            None, {net_feed_input[0]: one_img.detach().numpy()})
-        # only compare a part of result
-        bbox_results = bbox2result(det_bboxes, det_labels, 1)
-        onnx_results = bbox_results[0]
-        assert np.allclose(
-            pytorch_result[0][0][0][:4], onnx_results[0]
-            [:4]), 'The outputs are different between Pytorch and ONNX'
+        onnx_outputs = sess.run(None,
+                                {net_feed_input[0]: one_img.detach().numpy()})
+        output_names = [_.name for _ in sess.get_outputs()]
+        output_shapes = [_.shape for _ in onnx_outputs]
+        print(f'onnxruntime output names: {output_names}, \
+output shapes: {output_shapes}')
+        nrof_out = len(onnx_outputs)
+        assert nrof_out > 0, 'Must have output'
+        with_mask = nrof_out == 3
+        if nrof_out == 1:
+            onnx_results = onnx_outputs[0]
+        else:
+            det_bboxes, det_labels = onnx_outputs[:2]
+            onnx_results = bbox2result(det_bboxes, det_labels, num_classes)
+            if with_mask:
+                segm_results = onnx_outputs[2].squeeze(1)
+                cls_segms = [[] for _ in range(num_classes)]
+                for i in range(det_bboxes.shape[0]):
+                    cls_segms[det_labels[i]].append(segm_results[i])
+                onnx_results = (onnx_results, cls_segms)
+
+        # visualize predictions
+        if show:
+            show_result_pyplot(
+                model,
+                one_meta['show_img'],
+                pytorch_results,
+                title='Pytorch',
+                block=False)
+            show_result_pyplot(
+                model, one_meta['show_img'], onnx_results, title='ONNX')
+
+        # compare a part of result
+        if with_mask:
+            compare_pairs = list(zip(onnx_results, pytorch_results))
+        else:
+            compare_pairs = [(onnx_results, pytorch_results)]
+        for onnx_res, pytorch_res in compare_pairs:
+            for o_res, p_res in zip(onnx_res, pytorch_res):
+                np.testing.assert_allclose(
+                    o_res,
+                    p_res,
+                    rtol=1e-03,
+                    atol=1e-05,
+                )
+        print('The numerical values are the same between Pytorch and ONNX')
@@ -82,6 +136,12 @@ def parse_args():
     parser.add_argument('--show', action='store_true', help='show onnx graph')
     parser.add_argument('--output-file', type=str, default='tmp.onnx')
     parser.add_argument('--opset-version', type=int, default=11)
+    parser.add_argument(
+        '--test-img', type=str, default=None, help='Images for test')
+    parser.add_argument(
+        '--dataset', type=str, default='coco', help='Dataset name')
+    parser.add_argument(
+        '--view', action='store_true', help='Visualize results')
     parser.add_argument(
         '--verify',
         action='store_true',
@@ -139,4 +199,5 @@ if __name__ == '__main__':
         show=args.show,
         output_file=args.output_file,
         verify=args.verify,
-        normalize_cfg=normalize_cfg)
+        normalize_cfg=normalize_cfg,
+        dataset=args.dataset)
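
For reference, a direct call mirroring the new flags; this is a sketch assuming the signature shown in the first tools/pytorch2onnx.py hunk (the config/checkpoint paths, input shape, and normalization numbers are placeholders, not values from this diff):

pytorch2onnx(
    'configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py',
    'checkpoints/mask_rcnn_r50_fpn_1x_coco.pth',
    input_shape=(1, 3, 800, 1216),
    opset_version=11,
    output_file='mask_rcnn.onnx',
    verify=True,
    normalize_cfg={
        'mean': [123.675, 116.28, 103.53],
        'std': [58.395, 57.12, 57.375],
    },
    dataset='coco')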
