From f5ec92e4ca3b1f293ee83d5feca1fb3b43bfde47 Mon Sep 17 00:00:00 2001 From: Yuantao Feng Date: Wed, 6 Dec 2023 04:43:54 -0600 Subject: [PATCH] Merge pull request #24655 from fengyuentau:graph_simplifier_optional_input dnn onnx graph simplifier: handle optional inputs of Slice #24655 Resolves https://github.com/opencv/opencv/issues/24609 ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake --- .../dnn/src/onnx/onnx_graph_simplifier.cpp | 76 ++++++++++++++++++- 1 file changed, 75 insertions(+), 1 deletion(-) diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index 484ee7c09e..7b9ec4082a 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -82,6 +82,23 @@ public: return makePtr(node); } + int getTensorShapeSize(int node_id, int node_input_id) { + const auto node = getNode(node_id); + const auto &input_name = node->getInputName(node_input_id); + for (int i = 0; i < net.value_info_size(); i++) { + const auto value_info = net.value_info(i); + if (value_info.name() == input_name) { + if (value_info.has_type() && value_info.type().has_tensor_type() && + value_info.type().tensor_type().has_shape()) { + return value_info.type().tensor_type().shape().dim_size(); + } else { + return -1; + } + } + } + return -1; + } + int getInputInitializerId(int node_id, int node_input_id) { auto node = getNode(node_id); @@ -164,6 +181,61 @@ static Mat extractConstant(const Ptr& net, int node_id, int } } +/* Slice operator has two optional inputs "axes" and "steps". Some models may be set to have + Slice with optional inputs of default values, some of them don't. This Subgraph removes + all optional inputs of Slice if values are default. +*/ +class RemoveSliceAllOptionalInputsSubgraph : public Subgraph { + public: + RemoveSliceAllOptionalInputsSubgraph(size_t num_inputs = 4) { + num_inputs_ = num_inputs; + + int input = addNodeToMatch(""); + int starts = addNodeToMatch(""); + int ends = addNodeToMatch(""); + std::vector inputs{input, starts, ends}; + for (size_t i = 3; i < num_inputs_; i++) { // axes and steps + inputs.push_back(addNodeToMatch("")); + } + + slice_id = addNodeToMatch("Slice", inputs); + + setFusedNode("Slice", std::vector{input, starts, ends}); + } + + virtual bool match(const Ptr& net, int nodeId, + std::vector& matchedNodesIds) CV_OVERRIDE { + if (Subgraph::match(net, nodeId, matchedNodesIds)) { + if (num_inputs_ >= 4) { // with axes + // Check if axes are -1 or last axis + auto onnx_net = net.dynamicCast(); + int shape_size = onnx_net->getTensorShapeSize(matchedNodesIds[slice_id], 0); + + auto axes = extractConstant(net, matchedNodesIds[slice_id], 3); + for (size_t i = 0; i < axes.total(); i++) { + const int axis = *(axes.ptr() + i); + if (axis != -1 && axis != shape_size - 1) { + return false; + } + } + } + if (num_inputs_ == 5) { + // Check if steps are 1 + auto steps = extractConstant(net, matchedNodesIds[slice_id], 4); + if (countNonZero(steps != 1)) { + return false; + } + } + return true; + } + return false; + } + +private: + int slice_id; + size_t num_inputs_; +}; + /* Fusion for Gelu. Graph before fusion: @@ -1091,7 +1163,7 @@ public: int cast = addNodeToMatch("Cast", concat1); int shape2 = addNodeToMatch("Shape", input); - int slice = addNodeToMatch("Slice", shape2, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant")); + int slice = addNodeToMatch("Slice", shape2, addNodeToMatch("Constant"), addNodeToMatch("Constant")); int concat2 = addNodeToMatch("Concat", slice, cast); addNodeToMatch("Resize", input, addNodeToMatch("Constant"), addNodeToMatch("Constant"), concat2); @@ -1163,6 +1235,8 @@ public: void simplifySubgraphs(opencv_onnx::GraphProto& net) { std::vector > subgraphs; + subgraphs.push_back(makePtr(4)); + subgraphs.push_back(makePtr(5)); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr());