mirror of https://github.com/opencv/opencv.git
parent
9da9e8244b
commit
ea2527c2d1
8 changed files with 206 additions and 3 deletions
@ -0,0 +1,43 @@ |
|||||||
|
// This file is part of OpenCV project.
|
||||||
|
// It is subject to the license terms in the LICENSE file found in the top-level
|
||||||
|
// directory of this distribution and at http://opencv.org/license.html.
|
||||||
|
|
||||||
|
#ifndef OPENCV_GAPI_INFER_BINDINGS_ONNX_HPP |
||||||
|
#define OPENCV_GAPI_INFER_BINDINGS_ONNX_HPP |
||||||
|
|
||||||
|
#include <opencv2/gapi/gkernel.hpp> // GKernelPackage |
||||||
|
#include <opencv2/gapi/infer/onnx.hpp> // Params |
||||||
|
#include "opencv2/gapi/own/exports.hpp" // GAPI_EXPORTS |
||||||
|
#include <opencv2/gapi/util/any.hpp> |
||||||
|
|
||||||
|
#include <string> |
||||||
|
|
||||||
|
namespace cv { |
||||||
|
namespace gapi { |
||||||
|
namespace onnx { |
||||||
|
|
||||||
|
// NB: Used by python wrapper
|
||||||
|
// This class can be marked as SIMPLE, because it's implemented as pimpl
|
||||||
|
class GAPI_EXPORTS_W_SIMPLE PyParams { |
||||||
|
public: |
||||||
|
GAPI_WRAP |
||||||
|
PyParams() = default; |
||||||
|
|
||||||
|
GAPI_WRAP |
||||||
|
PyParams(const std::string& tag, const std::string& model_path); |
||||||
|
|
||||||
|
GBackend backend() const; |
||||||
|
std::string tag() const; |
||||||
|
cv::util::any params() const; |
||||||
|
|
||||||
|
private: |
||||||
|
std::shared_ptr<Params<cv::gapi::Generic>> m_priv; |
||||||
|
}; |
||||||
|
|
||||||
|
GAPI_EXPORTS_W PyParams params(const std::string& tag, const std::string& model_path); |
||||||
|
|
||||||
|
} // namespace onnx
|
||||||
|
} // namespace gapi
|
||||||
|
} // namespace cv
|
||||||
|
|
||||||
|
#endif // OPENCV_GAPI_INFER_BINDINGS_ONNX_HPP
|
@ -0,0 +1,74 @@ |
|||||||
|
#!/usr/bin/env python |
||||||
|
|
||||||
|
import numpy as np |
||||||
|
import cv2 as cv |
||||||
|
import os |
||||||
|
import sys |
||||||
|
import unittest |
||||||
|
|
||||||
|
from tests_common import NewOpenCVTests |
||||||
|
|
||||||
|
|
||||||
|
try: |
||||||
|
|
||||||
|
if sys.version_info[:2] < (3, 0): |
||||||
|
raise unittest.SkipTest('Python 2.x is not supported') |
||||||
|
|
||||||
|
CLASSIFICATION_MODEL_PATH = "onnx_models/vision/classification/squeezenet/model/squeezenet1.0-9.onnx" |
||||||
|
|
||||||
|
testdata_required = bool(os.environ.get('OPENCV_DNN_TEST_REQUIRE_TESTDATA', False)) |
||||||
|
|
||||||
|
class test_gapi_infer(NewOpenCVTests): |
||||||
|
def find_dnn_file(self, filename, required=None): |
||||||
|
if not required: |
||||||
|
required = testdata_required |
||||||
|
return self.find_file(filename, [os.environ.get('OPENCV_DNN_TEST_DATA_PATH', os.getcwd()), |
||||||
|
os.environ['OPENCV_TEST_DATA_PATH']], |
||||||
|
required=required) |
||||||
|
|
||||||
|
def test_onnx_classification(self): |
||||||
|
model_path = self.find_dnn_file(CLASSIFICATION_MODEL_PATH) |
||||||
|
|
||||||
|
if model_path is None: |
||||||
|
raise unittest.SkipTest("Missing DNN test file") |
||||||
|
|
||||||
|
in_mat = cv.imread( |
||||||
|
self.find_file("cv/dpm/cat.png", |
||||||
|
[os.environ.get('OPENCV_TEST_DATA_PATH')])) |
||||||
|
|
||||||
|
g_in = cv.GMat() |
||||||
|
g_infer_inputs = cv.GInferInputs() |
||||||
|
g_infer_inputs.setInput("data_0", g_in) |
||||||
|
g_infer_out = cv.gapi.infer("squeeze-net", g_infer_inputs) |
||||||
|
g_out = g_infer_out.at("softmaxout_1") |
||||||
|
|
||||||
|
comp = cv.GComputation(cv.GIn(g_in), cv.GOut(g_out)) |
||||||
|
|
||||||
|
net = cv.gapi.onnx.params("squeeze-net", model_path) |
||||||
|
try: |
||||||
|
out_gapi = comp.apply(cv.gin(in_mat), cv.gapi.compile_args(cv.gapi.networks(net))) |
||||||
|
except cv.error as err: |
||||||
|
if err.args[0] == "G-API has been compiled without ONNX support": |
||||||
|
raise unittest.SkipTest("G-API has been compiled without ONNX support") |
||||||
|
else: |
||||||
|
raise |
||||||
|
|
||||||
|
self.assertEqual((1, 1000, 1, 1), out_gapi.shape) |
||||||
|
|
||||||
|
|
||||||
|
except unittest.SkipTest as e: |
||||||
|
|
||||||
|
message = str(e) |
||||||
|
|
||||||
|
class TestSkip(unittest.TestCase): |
||||||
|
def setUp(self): |
||||||
|
self.skipTest('Skip tests: ' + message) |
||||||
|
|
||||||
|
def test_skip(): |
||||||
|
pass |
||||||
|
|
||||||
|
pass |
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__': |
||||||
|
NewOpenCVTests.bootstrap() |
@ -0,0 +1,24 @@ |
|||||||
|
// This file is part of OpenCV project.
|
||||||
|
// It is subject to the license terms in the LICENSE file found in the top-level
|
||||||
|
// directory of this distribution and at http://opencv.org/license.html.
|
||||||
|
|
||||||
|
#include <opencv2/gapi/infer/bindings_onnx.hpp> |
||||||
|
|
||||||
|
cv::gapi::onnx::PyParams::PyParams(const std::string& tag, |
||||||
|
const std::string& model_path) |
||||||
|
: m_priv(std::make_shared<Params<cv::gapi::Generic>>(tag, model_path)) {} |
||||||
|
|
||||||
|
cv::gapi::GBackend cv::gapi::onnx::PyParams::backend() const { |
||||||
|
return m_priv->backend(); |
||||||
|
} |
||||||
|
|
||||||
|
std::string cv::gapi::onnx::PyParams::tag() const { return m_priv->tag(); } |
||||||
|
|
||||||
|
cv::util::any cv::gapi::onnx::PyParams::params() const { |
||||||
|
return m_priv->params(); |
||||||
|
} |
||||||
|
|
||||||
|
cv::gapi::onnx::PyParams cv::gapi::onnx::params( |
||||||
|
const std::string& tag, const std::string& model_path) { |
||||||
|
return {tag, model_path}; |
||||||
|
} |
Loading…
Reference in new issue