From 332748dd557fc79fa26964a38625bb376a7cb2e8 Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Fri, 24 Nov 2023 10:40:32 +0300 Subject: [PATCH] Merge pull request #24577 from dkurt:dnn_graph_match_stack Fix graph fusion with commutative ops #24577 ### Pull Request Readiness Checklist resolves https://github.com/opencv/opencv/issues/24568 **Merge with extra**: https://github.com/opencv/opencv_extra/pull/1125 TODO: - [x] replace recursive function to sequential See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake --- modules/dnn/src/graph_simplifier.cpp | 111 +++++++++++------- .../dnn/src/onnx/onnx_graph_simplifier.cpp | 6 + modules/dnn/src/onnx/onnx_importer.cpp | 2 +- modules/dnn/test/test_graph_simplifier.cpp | 1 + 4 files changed, 77 insertions(+), 43 deletions(-) diff --git a/modules/dnn/src/graph_simplifier.cpp b/modules/dnn/src/graph_simplifier.cpp index b9684afe69..2e1dc400be 100644 --- a/modules/dnn/src/graph_simplifier.cpp +++ b/modules/dnn/src/graph_simplifier.cpp @@ -81,26 +81,45 @@ bool Subgraph::match(const Ptr& net, int nodeId, { matchedNodesIds.clear(); - std::queue nodesToMatch; - std::queue targetNodes; - std::vector > matchings; - matchings.reserve(nodes.size()); - nodesToMatch.push(nodeId); - targetNodes.push(nodes.size() - 1); - while (!nodesToMatch.empty()) + // Collection of all matchings states across branching. + // If there is no commutative ops in the subgraph - there would be just a single map. + std::vector>> matchCandidates; + matchCandidates.push_back(makePtr>()); + + struct State + { + int nodeToMatch; + int targetNodeId; + // Every state refers to current matchings pairs as well as + // matchings from parent branches produced by commutative ops. + std::vector>> matchings; + + // When we register a matching pair we should register it in every parent branch. + // This is actual for branching in case of commutative ops only. + void addMatch(std::pair match) + { + for (auto& m : matchings) + m->insert(match); + } + }; + + std::queue states; + states.push({nodeId, (int)nodes.size() - 1, matchCandidates}); + + while (!states.empty()) { - int nodeToMatch = nodesToMatch.front(); - int targetNodeId = targetNodes.front(); - nodesToMatch.pop(); - targetNodes.pop(); + auto state = states.front(); + states.pop(); + int nodeToMatch = state.nodeToMatch; + int targetNodeId = state.targetNodeId; + auto matchings = state.matchings.back(); - if (std::find_if(matchings.begin(), matchings.end(), [&](const std::pair& match){ return match.first == targetNodeId; }) != - matchings.end()) + if (matchings->find(targetNodeId) != matchings->end()) continue; // Empty placeholder matches with any input type if (nodes[targetNodeId].empty()) { - matchings.push_back({targetNodeId, nodeToMatch}); + state.addMatch({targetNodeId, nodeToMatch}); continue; } @@ -112,42 +131,50 @@ bool Subgraph::match(const Ptr& net, int nodeId, if (inputNodes.size() != node->getNumInputs()) continue; - bool isCommutative = net->isCommutativeOp(node->getType()); + state.addMatch({targetNodeId, nodeToMatch}); - for (int j = 0; j < inputNodes.size(); ++j) + bool isCommutative = net->isCommutativeOp(node->getType()); + if (isCommutative) { - // Sometimes, ONNX may have input but it's empty (see Clip layer from reduceL2_subgraph2_2 testcase) - if (node->getInputName(j).empty()) - continue; - nodeId = getInputNodeId(net, node, j); - const Ptr inpNode = net->getNode(nodeId); - if (isCommutative) - { - for (int i = 0; i < inputNodes.size(); ++i) - { - nodesToMatch.push(nodeId); - targetNodes.push(inputNodes[i]); - } - } - else + if (inputNodes.size() != 2) + CV_Error(Error::StsNotImplemented, "Commutative op fusion with more than 2 inputs"); + + auto newMatchings = makePtr>(*matchings); + matchCandidates.push_back(newMatchings); + state.matchings.push_back(newMatchings); + states.push({getInputNodeId(net, node, 0), inputNodes[0], state.matchings}); + states.push({getInputNodeId(net, node, 1), inputNodes[1], state.matchings}); + state.matchings.pop_back(); + + newMatchings = makePtr>(*matchings); + matchCandidates.push_back(newMatchings); + state.matchings.push_back(newMatchings); + states.push({getInputNodeId(net, node, 0), inputNodes[1], state.matchings}); + states.push({getInputNodeId(net, node, 1), inputNodes[0], state.matchings}); + state.matchings.pop_back(); + } + else + { + for (int j = 0; j < inputNodes.size(); ++j) { - nodesToMatch.push(nodeId); - targetNodes.push(inputNodes[j]); + nodeId = getInputNodeId(net, node, j); + states.push({nodeId, inputNodes[j], state.matchings}); } } - matchings.push_back({targetNodeId, nodeToMatch}); } - if (matchings.size() != nodes.size()) - return false; - - // Sort matched by pattern nodes order. - std::sort(matchings.begin(), matchings.end()); - matchedNodesIds.resize(matchings.size()); - for (int i = 0; i < matchings.size(); ++i) + for (auto& matchings : matchCandidates) { - matchedNodesIds[i] = matchings[i].second; + if (matchings->size() != nodes.size()) + continue; + matchedNodesIds.resize(matchings->size()); + for (int i = 0; i < matchings->size(); ++i) + { + CV_Assert(matchings->find(i) != matchings->end()); + matchedNodesIds[i] = matchings->at(i); + } + return true; } - return true; + return false; } void Subgraph::replace(const Ptr& net, const std::vector& matchedNodesIds) diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index a8f4058d50..484ee7c09e 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -64,6 +64,12 @@ class ONNXGraphWrapper : public ImportGraphWrapper public: ONNXGraphWrapper(opencv_onnx::GraphProto& _net) : net(_net) { + // Add a fake initializer with empty name. + // Some ONNX models skip their inputs. For example, + // Resize which has 4 inputs but 2 of them have empty names. + // So we add a fake empty node to which such ops may refer as input. + net.add_initializer(); + numInputs = net.input_size(); numInitializers = net.initializer_size(); } diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index ba4d542731..3d1b4d8b82 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -3539,7 +3539,7 @@ void ONNXImporter::parseQGemm(LayerParams& layerParams, const opencv_onnx::NodeP Mat bias; if (constBlobs.find(node_proto.input(6)) != constBlobs.end()) bias = getBlob(node_proto, 6); - else + if (bias.empty()) bias = Mat::zeros(1, outCn, CV_32S); Mat biasFused(1, outCn, CV_32S); diff --git a/modules/dnn/test/test_graph_simplifier.cpp b/modules/dnn/test/test_graph_simplifier.cpp index 58fffd3713..f6b85de230 100644 --- a/modules/dnn/test/test_graph_simplifier.cpp +++ b/modules/dnn/test/test_graph_simplifier.cpp @@ -35,6 +35,7 @@ class Test_Graph_Simplifier : public ::testing::Test { TEST_F(Test_Graph_Simplifier, GeluSubGraph) { test("gelu", "Gelu"); + test("bias_gelu", std::vector{"Gelu", "NaryEltwise"}); } TEST_F(Test_Graph_Simplifier, GeluApproximationSubGraph) {