OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io/
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
42 lines
1.3 KiB
42 lines
1.3 KiB
# Copyright (c) OpenMMLab. All rights reserved. |
|
# Copyright (c) 2019 Western Digital Corporation or its affiliates. |
|
import torch |
|
|
|
from ..builder import DETECTORS |
|
from .single_stage import SingleStageDetector |
|
|
|
|
|
@DETECTORS.register_module()
class YOLOV3(SingleStageDetector):
    """YOLOv3 detector built on top of ``SingleStageDetector``.

    The class adds no training or inference logic of its own. It only
    registers itself in ``DETECTORS`` and provides an ONNX-export entry
    point.
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        """Build the detector.

        All arguments are forwarded unchanged, and in this exact positional
        order, to ``SingleStageDetector.__init__``.
        """
        super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
                         pretrained, init_cfg)

    def onnx_export(self, img, img_metas):
        """Run the detector for ONNX export, without test-time augmentation.

        Args:
            img (torch.Tensor): Input images.
            img_metas (list[dict]): Per-image information. The first dict is
                modified in place: key ``'img_shape_for_onnx'`` is set.

        Returns:
            tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
                and class labels of shape [N, num_det].
        """
        feats = self.extract_feat(img)
        head_outputs = self.bbox_head.forward(feats)

        # Take the shape as a tensor rather than Python ints, and drop the
        # first two dims. With NCHW input (assumed, TODO confirm) this
        # leaves (H, W). Presumably done so tracing keeps it as a graph
        # value instead of baking in a constant.
        spatial_shape = torch._shape_as_tensor(img)[2:]
        first_meta = img_metas[0]
        first_meta['img_shape_for_onnx'] = spatial_shape

        # Unpack explicitly so a plain 2-tuple is returned whatever sequence
        # type the head hands back.
        det_bboxes, det_labels = self.bbox_head.onnx_export(
            *head_outputs, img_metas)
        return det_bboxes, det_labels
|
|
|