diff --git a/modules/gapi/include/opencv2/gapi/infer/bindings_onnx.hpp b/modules/gapi/include/opencv2/gapi/infer/bindings_onnx.hpp index 0b6dab6a9d..f7bb259924 100644 --- a/modules/gapi/include/opencv2/gapi/infer/bindings_onnx.hpp +++ b/modules/gapi/include/opencv2/gapi/infer/bindings_onnx.hpp @@ -54,6 +54,9 @@ public: GAPI_WRAP PyParams& cfgSessionOptions(const std::map& options); + GAPI_WRAP + PyParams& cfgOptLevel(const int opt_level); + GBackend backend() const; std::string tag() const; cv::util::any params() const; diff --git a/modules/gapi/include/opencv2/gapi/infer/onnx.hpp b/modules/gapi/include/opencv2/gapi/infer/onnx.hpp index fd0f69a768..eb6316b446 100644 --- a/modules/gapi/include/opencv2/gapi/infer/onnx.hpp +++ b/modules/gapi/include/opencv2/gapi/infer/onnx.hpp @@ -15,6 +15,7 @@ #include #include +#include #include // GAPI_EXPORTS #include // GKernelPackage @@ -354,6 +355,7 @@ struct ParamDesc { std::map session_options; std::vector execution_providers; bool disable_mem_pattern; + cv::util::optional opt_level; }; } // namespace detail @@ -648,6 +650,17 @@ public: return *this; } + /** @brief Configures optimization level for ONNX Runtime. + + @param opt_level [optimization level]: Valid values are 0 (disable), 1 (basic), 2 (extended), 99 (all). + Please see onnxruntime_c_api.h (enum GraphOptimizationLevel) for the full list of all optimization levels. + @return the reference on modified object. + */ + Params& cfgOptLevel(const int opt_level) { + desc.opt_level = cv::util::make_optional(opt_level); + return *this; + } + // BEGIN(G-API's network parametrization API) GBackend backend() const { return cv::gapi::onnx::backend(); } std::string tag() const { return Net::tag(); } @@ -675,7 +688,7 @@ public: @param model_path path to model file (.onnx file). */ Params(const std::string& tag, const std::string& model_path) - : desc{model_path, 0u, 0u, {}, {}, {}, {}, {}, {}, {}, {}, {}, true, {}, {}, {}, {}, false}, m_tag(tag) {} + : desc{ model_path, 0u, 0u, {}, {}, {}, {}, {}, {}, {}, {}, {}, true, {}, {}, {}, {}, false, {} }, m_tag(tag) {} /** @see onnx::Params::cfgMeanStdDev. */ void cfgMeanStdDev(const std::string &layer, @@ -724,6 +737,11 @@ public: desc.session_options.insert(options.begin(), options.end()); } +/** @see onnx::Params::cfgOptLevel. */ + void cfgOptLevel(const int opt_level) { + desc.opt_level = cv::util::make_optional(opt_level); + } + // BEGIN(G-API's network parametrization API) GBackend backend() const { return cv::gapi::onnx::backend(); } std::string tag() const { return m_tag; } diff --git a/modules/gapi/src/backends/onnx/bindings_onnx.cpp b/modules/gapi/src/backends/onnx/bindings_onnx.cpp index 294ad8a3cc..5a2e3d2f6d 100644 --- a/modules/gapi/src/backends/onnx/bindings_onnx.cpp +++ b/modules/gapi/src/backends/onnx/bindings_onnx.cpp @@ -63,6 +63,12 @@ cv::gapi::onnx::PyParams::cfgSessionOptions(const std::mapcfgOptLevel(opt_level); + return *this; +} + cv::gapi::GBackend cv::gapi::onnx::PyParams::backend() const { return m_priv->backend(); } diff --git a/modules/gapi/src/backends/onnx/gonnxbackend.cpp b/modules/gapi/src/backends/onnx/gonnxbackend.cpp index 0d9a16a7bd..fc9b12b081 100644 --- a/modules/gapi/src/backends/onnx/gonnxbackend.cpp +++ b/modules/gapi/src/backends/onnx/gonnxbackend.cpp @@ -701,6 +701,26 @@ namespace cv { namespace gimpl { namespace onnx { +static GraphOptimizationLevel convertToGraphOptimizationLevel(const int opt_level) { + switch (opt_level) { + case ORT_DISABLE_ALL: + return ORT_DISABLE_ALL; + case ORT_ENABLE_BASIC: + return ORT_ENABLE_BASIC; + case ORT_ENABLE_EXTENDED: + return ORT_ENABLE_EXTENDED; + case ORT_ENABLE_ALL: + return ORT_ENABLE_ALL; + default: + if (opt_level > ORT_ENABLE_ALL) { // relax constraint + return ORT_ENABLE_ALL; + } + else { + cv::util::throw_error(std::invalid_argument("Invalid argument opt_level = " + std::to_string(opt_level))); + } + } +} + ONNXCompiled::ONNXCompiled(const gapi::onnx::detail::ParamDesc &pp) : params(pp) { // Validate input parameters before allocating any resources @@ -726,6 +746,10 @@ ONNXCompiled::ONNXCompiled(const gapi::onnx::detail::ParamDesc &pp) if (pp.disable_mem_pattern) { session_options.DisableMemPattern(); } + + if (pp.opt_level.has_value()) { + session_options.SetGraphOptimizationLevel(convertToGraphOptimizationLevel(pp.opt_level.value())); + } this_env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, ""); #ifndef _WIN32 this_session = Ort::Session(this_env, params.model_path.data(), session_options);