mirror of https://github.com/opencv/opencv.git
Merge pull request #22017 from xiong-jie-y:py_onnx
Add python bindings for G-API onnxpull/22501/head
commit
fef8d4c990
8 changed files with 206 additions and 3 deletions
@ -0,0 +1,43 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level
|
||||
// directory of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#ifndef OPENCV_GAPI_INFER_BINDINGS_ONNX_HPP |
||||
#define OPENCV_GAPI_INFER_BINDINGS_ONNX_HPP |
||||
|
||||
#include <opencv2/gapi/gkernel.hpp> // GKernelPackage |
||||
#include <opencv2/gapi/infer/onnx.hpp> // Params |
||||
#include "opencv2/gapi/own/exports.hpp" // GAPI_EXPORTS |
||||
#include <opencv2/gapi/util/any.hpp> |
||||
|
||||
#include <string> |
||||
|
||||
namespace cv { |
||||
namespace gapi { |
||||
namespace onnx { |
||||
|
||||
// NB: Used by python wrapper
|
||||
// This class can be marked as SIMPLE, because it's implemented as pimpl
|
||||
class GAPI_EXPORTS_W_SIMPLE PyParams { |
||||
public: |
||||
GAPI_WRAP |
||||
PyParams() = default; |
||||
|
||||
GAPI_WRAP |
||||
PyParams(const std::string& tag, const std::string& model_path); |
||||
|
||||
GBackend backend() const; |
||||
std::string tag() const; |
||||
cv::util::any params() const; |
||||
|
||||
private: |
||||
std::shared_ptr<Params<cv::gapi::Generic>> m_priv; |
||||
}; |
||||
|
||||
GAPI_EXPORTS_W PyParams params(const std::string& tag, const std::string& model_path); |
||||
|
||||
} // namespace onnx
|
||||
} // namespace gapi
|
||||
} // namespace cv
|
||||
|
||||
#endif // OPENCV_GAPI_INFER_BINDINGS_ONNX_HPP
|
@ -0,0 +1,74 @@ |
||||
#!/usr/bin/env python |
||||
|
||||
import numpy as np |
||||
import cv2 as cv |
||||
import os |
||||
import sys |
||||
import unittest |
||||
|
||||
from tests_common import NewOpenCVTests |
||||
|
||||
|
||||
try: |
||||
|
||||
if sys.version_info[:2] < (3, 0): |
||||
raise unittest.SkipTest('Python 2.x is not supported') |
||||
|
||||
CLASSIFICATION_MODEL_PATH = "onnx_models/vision/classification/squeezenet/model/squeezenet1.0-9.onnx" |
||||
|
||||
testdata_required = bool(os.environ.get('OPENCV_DNN_TEST_REQUIRE_TESTDATA', False)) |
||||
|
||||
class test_gapi_infer(NewOpenCVTests): |
||||
def find_dnn_file(self, filename, required=None): |
||||
if not required: |
||||
required = testdata_required |
||||
return self.find_file(filename, [os.environ.get('OPENCV_DNN_TEST_DATA_PATH', os.getcwd()), |
||||
os.environ['OPENCV_TEST_DATA_PATH']], |
||||
required=required) |
||||
|
||||
def test_onnx_classification(self): |
||||
model_path = self.find_dnn_file(CLASSIFICATION_MODEL_PATH) |
||||
|
||||
if model_path is None: |
||||
raise unittest.SkipTest("Missing DNN test file") |
||||
|
||||
in_mat = cv.imread( |
||||
self.find_file("cv/dpm/cat.png", |
||||
[os.environ.get('OPENCV_TEST_DATA_PATH')])) |
||||
|
||||
g_in = cv.GMat() |
||||
g_infer_inputs = cv.GInferInputs() |
||||
g_infer_inputs.setInput("data_0", g_in) |
||||
g_infer_out = cv.gapi.infer("squeeze-net", g_infer_inputs) |
||||
g_out = g_infer_out.at("softmaxout_1") |
||||
|
||||
comp = cv.GComputation(cv.GIn(g_in), cv.GOut(g_out)) |
||||
|
||||
net = cv.gapi.onnx.params("squeeze-net", model_path) |
||||
try: |
||||
out_gapi = comp.apply(cv.gin(in_mat), cv.gapi.compile_args(cv.gapi.networks(net))) |
||||
except cv.error as err: |
||||
if err.args[0] == "G-API has been compiled without ONNX support": |
||||
raise unittest.SkipTest("G-API has been compiled without ONNX support") |
||||
else: |
||||
raise |
||||
|
||||
self.assertEqual((1, 1000, 1, 1), out_gapi.shape) |
||||
|
||||
|
||||
except unittest.SkipTest as e: |
||||
|
||||
message = str(e) |
||||
|
||||
class TestSkip(unittest.TestCase): |
||||
def setUp(self): |
||||
self.skipTest('Skip tests: ' + message) |
||||
|
||||
def test_skip(): |
||||
pass |
||||
|
||||
pass |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
NewOpenCVTests.bootstrap() |
@ -0,0 +1,24 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level
|
||||
// directory of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include <opencv2/gapi/infer/bindings_onnx.hpp> |
||||
|
||||
cv::gapi::onnx::PyParams::PyParams(const std::string& tag, |
||||
const std::string& model_path) |
||||
: m_priv(std::make_shared<Params<cv::gapi::Generic>>(tag, model_path)) {} |
||||
|
||||
cv::gapi::GBackend cv::gapi::onnx::PyParams::backend() const { |
||||
return m_priv->backend(); |
||||
} |
||||
|
||||
std::string cv::gapi::onnx::PyParams::tag() const { return m_priv->tag(); } |
||||
|
||||
cv::util::any cv::gapi::onnx::PyParams::params() const { |
||||
return m_priv->params(); |
||||
} |
||||
|
||||
cv::gapi::onnx::PyParams cv::gapi::onnx::params( |
||||
const std::string& tag, const std::string& model_path) { |
||||
return {tag, model_path}; |
||||
} |
Loading…
Reference in new issue