From b7ec2ebb55309a7b880edf5e5f483b537cbefe86 Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Wed, 8 Nov 2023 16:26:33 +0300 Subject: [PATCH] Merge pull request #24483 from dkurt:dnn_fusion_commutative_ops Commutative rules for DNN subgraphs fusion #24483 ### Pull Request Readiness Checklist related: https://github.com/opencv/opencv/pull/24463#issuecomment-1783033931 See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake --- modules/dnn/src/graph_simplifier.cpp | 68 ++-- modules/dnn/src/graph_simplifier.hpp | 8 +- .../dnn/src/onnx/onnx_graph_simplifier.cpp | 304 +++++++----------- .../src/tensorflow/tf_graph_simplifier.cpp | 27 +- 4 files changed, 179 insertions(+), 228 deletions(-) diff --git a/modules/dnn/src/graph_simplifier.cpp b/modules/dnn/src/graph_simplifier.cpp index e1b6d6df40..b9684afe69 100644 --- a/modules/dnn/src/graph_simplifier.cpp +++ b/modules/dnn/src/graph_simplifier.cpp @@ -77,14 +77,14 @@ int Subgraph::getInputNodeId(const Ptr& net, } bool Subgraph::match(const Ptr& net, int nodeId, - std::vector& matchedNodesIds, - std::vector& targetNodesIds) + std::vector& matchedNodesIds) { matchedNodesIds.clear(); - targetNodesIds.clear(); std::queue nodesToMatch; std::queue targetNodes; + std::vector > matchings; + matchings.reserve(nodes.size()); nodesToMatch.push(nodeId); targetNodes.push(nodes.size() - 1); while (!nodesToMatch.empty()) @@ -94,51 +94,63 @@ bool Subgraph::match(const Ptr& net, int nodeId, nodesToMatch.pop(); targetNodes.pop(); - if (std::find(matchedNodesIds.begin(), matchedNodesIds.end(), nodeToMatch) != - matchedNodesIds.end()) + if (std::find_if(matchings.begin(), matchings.end(), [&](const std::pair& match){ return match.first == targetNodeId; }) != + matchings.end()) continue; + // Empty placeholder matches with any input type + if (nodes[targetNodeId].empty()) { + matchings.push_back({targetNodeId, nodeToMatch}); + continue; + } + const Ptr node = net->getNode(nodeToMatch); if (node->getType() != nodes[targetNodeId]) - return false; + continue; std::vector& inputNodes = inputs[targetNodeId]; if (inputNodes.size() != node->getNumInputs()) - return false; + continue; + + bool isCommutative = net->isCommutativeOp(node->getType()); for (int j = 0; j < inputNodes.size(); ++j) { - if (nodes[inputNodes[j]].empty() || node->getInputName(j).empty()) // Unknown input node type. + // Sometimes, ONNX may have input but it's empty (see Clip layer from reduceL2_subgraph2_2 testcase) + if (node->getInputName(j).empty()) continue; nodeId = getInputNodeId(net, node, j); const Ptr inpNode = net->getNode(nodeId); - if (inpNode->getType() != "Const" && inpNode->getType() != "Constant") + if (isCommutative) + { + for (int i = 0; i < inputNodes.size(); ++i) + { + nodesToMatch.push(nodeId); + targetNodes.push(inputNodes[i]); + } + } + else { nodesToMatch.push(nodeId); targetNodes.push(inputNodes[j]); } - else if (nodes[inputNodes[j]] != "Const" && nodes[inputNodes[j]] != "Constant") - return false; } - matchedNodesIds.push_back(nodeToMatch); - targetNodesIds.push_back(targetNodeId); + matchings.push_back({targetNodeId, nodeToMatch}); } + if (matchings.size() != nodes.size()) + return false; - const int n = matchedNodesIds.size(); - std::vector > elements(n); - for (int i = 0; i < n; ++i) - elements[i] = std::make_pair(matchedNodesIds[i], targetNodesIds[i]); - std::sort(elements.begin(), elements.end()); - for (int i = 0; i < n; ++i) + // Sort matched by pattern nodes order. + std::sort(matchings.begin(), matchings.end()); + matchedNodesIds.resize(matchings.size()); + for (int i = 0; i < matchings.size(); ++i) { - matchedNodesIds[i] = elements[i].first; - targetNodesIds[i] = elements[i].second; + matchedNodesIds[i] = matchings[i].second; } return true; } -void Subgraph::replace(const Ptr& net, const std::vector& matchedNodesIds, - const std::vector& targetNodesIds) +void Subgraph::replace(const Ptr& net, const std::vector& matchedNodesIds) { // Extract names of input nodes. std::vector inputsNames(fusedNodeInputs.size()); @@ -149,9 +161,9 @@ void Subgraph::replace(const Ptr& net, const std::vector node = net->getNode(matchedNodesIds[j]); - std::vector& inpIndices = inputs[targetNodesIds[j]]; + std::vector& inpIndices = inputs[j]; - CV_Assert(node->getNumInputs() == inpIndices.size()); + CV_Assert(inpIndices.empty() || node->getNumInputs() == inpIndices.size()); for (int k = 0; k < inpIndices.size(); ++k) { if (inpIndices[k] == fusedNodeInputs[i]) @@ -187,15 +199,15 @@ void simplifySubgraphs(const Ptr& net, const std::vector >& patterns) { int numNodes = net->getNumNodes(); - std::vector matchedNodesIds, targetNodesIds; + std::vector matchedNodesIds; std::vector nodesToRemove; for (int j = 0; j < patterns.size(); ++j) { for (int i = 0; i < numNodes; ++i) { - if (patterns[j]->match(net, i, matchedNodesIds, targetNodesIds)) + if (patterns[j]->match(net, i, matchedNodesIds)) { - patterns[j]->replace(net, matchedNodesIds, targetNodesIds); + patterns[j]->replace(net, matchedNodesIds); // Remove matched nodes except the last one. nodesToRemove.insert(nodesToRemove.end(), matchedNodesIds.begin(), matchedNodesIds.end() - 1); } diff --git a/modules/dnn/src/graph_simplifier.hpp b/modules/dnn/src/graph_simplifier.hpp index 22bc50e3e5..aa9be32a91 100644 --- a/modules/dnn/src/graph_simplifier.hpp +++ b/modules/dnn/src/graph_simplifier.hpp @@ -44,6 +44,8 @@ public: virtual std::string getOutputName(int nodeId, int outId) const = 0; virtual void removeNode(int idx) = 0; + + virtual bool isCommutativeOp(const std::string& type) const = 0; }; class Subgraph // Interface to match and replace subgraphs. @@ -75,12 +77,10 @@ public: // Match TensorFlow subgraph starting from with a set of nodes to be fused. // Const nodes are skipped during matching. Returns true if nodes are matched and can be fused. virtual bool match(const Ptr& net, int nodeId, - std::vector& matchedNodesIds, - std::vector& targetNodesIds); + std::vector& matchedNodesIds); // Fuse matched subgraph. - void replace(const Ptr& net, const std::vector& matchedNodesIds, - const std::vector& targetNodesIds); + void replace(const Ptr& net, const std::vector& matchedNodesIds); virtual void finalize(const Ptr& net, const Ptr& fusedNode, diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index 15f79c8769..a8f4058d50 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -125,8 +125,13 @@ public: virtual void removeNode(int idx) CV_OVERRIDE { - CV_Assert(idx >= numInputs + numInitializers); - net.mutable_node()->DeleteSubrange(idx - numInputs - numInitializers, 1); + if (idx >= numInputs + numInitializers) + net.mutable_node()->DeleteSubrange(idx - numInputs - numInitializers, 1); + } + + virtual inline bool isCommutativeOp(const std::string& type) const CV_OVERRIDE + { + return type == "Add" || type == "Mul" || type == "Equal" || type == "Max"; } private: @@ -134,6 +139,25 @@ private: opencv_onnx::GraphProto& net; }; +static Mat extractConstant(const Ptr& net, int node_id, int input_id) +{ + auto onnx_net = net.dynamicCast(); + int initializer_id = onnx_net->getInputInitializerId(node_id, input_id); + if (initializer_id != -1) + { + return onnx_net->getMatFromInitializer(initializer_id); + } + else + { + const Ptr node = net->getNode(node_id); + int constant_id = Subgraph::getInputNodeId(net, node, input_id); + Ptr constant_ptr = net->getNode(constant_id); + opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast()->node; + opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t(); + return getMatFromTensor(constant_proto); + } +} + /* Fusion for Gelu. Graph before fusion: @@ -151,54 +175,32 @@ public: GeluSubGraph() { int input = addNodeToMatch(""); - int div = addNodeToMatch("Div", input, addNodeToMatch("") /* B=sqrt(2) */ ); + div = addNodeToMatch("Div", input, addNodeToMatch("") /* B=sqrt(2) */ ); int erf = addNodeToMatch("Erf", div); - int add = addNodeToMatch("Add", erf, addNodeToMatch("") /* B=1 */ ); + add = addNodeToMatch("Add", erf, addNodeToMatch("") /* B=1 */ ); int mul = addNodeToMatch("Mul", input, add); - addNodeToMatch("Mul", mul, addNodeToMatch("") /* B=0.5 */) ; + mul2 = addNodeToMatch("Mul", mul, addNodeToMatch("") /* B=0.5 */) ; setFusedNode("Gelu", input); } - static float extractConstant(const Ptr& net, int node_id, int input_id) - { - auto onnx_net = net.dynamicCast(); - int initializer_id = onnx_net->getInputInitializerId(node_id, input_id); - if (initializer_id != -1) - { - Mat const_mat = onnx_net->getMatFromInitializer(initializer_id); - return *const_mat.ptr(); - } - else - { - const Ptr node = net->getNode(node_id); - int constant_id = getInputNodeId(net, node, input_id); - Ptr constant_ptr = net->getNode(constant_id); - opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast()->node; - opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t(); - Mat constant_mat = getMatFromTensor(constant_proto); - return *constant_mat.ptr(); - } - } - virtual bool match(const Ptr& net, int nodeId, - std::vector& matchedNodesIds, - std::vector& targetNodesIds) CV_OVERRIDE + std::vector& matchedNodesIds) CV_OVERRIDE { - if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) + if (Subgraph::match(net, nodeId, matchedNodesIds)) { // Check Div[B=sqrt(2)] - float divisor = extractConstant(net, matchedNodesIds[0], 1); + float divisor = extractConstant(net, matchedNodesIds[div], 1).at(0); if (std::fabs(divisor - M_SQRT2) >= std::numeric_limits::epsilon()) return false; // Check Add[B=1] - float add_const = extractConstant(net, matchedNodesIds[2], 1); + float add_const = extractConstant(net, matchedNodesIds[add], 1).at(0); if (std::fabs(add_const - 1.f) >= std::numeric_limits::epsilon()) return false; // Check Mul[B=0.5] - float mul_const = extractConstant(net, matchedNodesIds[4], 1); + float mul_const = extractConstant(net, matchedNodesIds[mul2], 1).at(0); if (std::fabs(mul_const - 0.5f) >= std::numeric_limits::epsilon()) return false; @@ -206,6 +208,9 @@ public: } return false; } + +private: + int div, add, mul2; }; /* Fusion for GeluApproximation. @@ -229,61 +234,39 @@ public: int input = addNodeToMatch(""); int mul0 = addNodeToMatch("Mul", input, input); int mul1 = addNodeToMatch("Mul", input, mul0); - int mul2 = addNodeToMatch("Mul", addNodeToMatch("") /* A=0.044714998453855515 */, mul1); + mul2 = addNodeToMatch("Mul", addNodeToMatch("") /* A=0.044714998453855515 */, mul1); int add0 = addNodeToMatch("Add", input, mul2); - int mul3 = addNodeToMatch("Mul", addNodeToMatch("") /* A=sqrt(2/pie) */, add0); + mul3 = addNodeToMatch("Mul", addNodeToMatch("") /* A=sqrt(2/pie) */, add0); int tanh = addNodeToMatch("Tanh", mul3); - int add1 = addNodeToMatch("Add", addNodeToMatch("") /* A=1 */, tanh); + add1 = addNodeToMatch("Add", addNodeToMatch("") /* A=1 */, tanh); int mul4 = addNodeToMatch("Mul", input, add1); - addNodeToMatch("Mul", addNodeToMatch("") /* A=0.5 */, mul4); + mul5 = addNodeToMatch("Mul", addNodeToMatch("") /* A=0.5 */, mul4); setFusedNode("GeluApproximation", input); } - static float extractConstant(const Ptr& net, int node_id, int input_id) - { - auto onnx_net = net.dynamicCast(); - int initializer_id = onnx_net->getInputInitializerId(node_id, input_id); - if (initializer_id != -1) - { - Mat const_mat = onnx_net->getMatFromInitializer(initializer_id); - return *const_mat.ptr(); - } - else - { - const Ptr node = net->getNode(node_id); - int constant_id = getInputNodeId(net, node, input_id); - Ptr constant_ptr = net->getNode(constant_id); - opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast()->node; - opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t(); - Mat constant_mat = getMatFromTensor(constant_proto); - return *constant_mat.ptr(); - } - } - virtual bool match(const Ptr& net, int nodeId, - std::vector& matchedNodesIds, - std::vector& targetNodesIds) CV_OVERRIDE + std::vector& matchedNodesIds) CV_OVERRIDE { - if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) + if (Subgraph::match(net, nodeId, matchedNodesIds)) { // Check Mul[A=0.044714998453855515] - float coef = extractConstant(net, matchedNodesIds[2], 0); + float coef = extractConstant(net, matchedNodesIds[mul2], 0).at(0); if (coef - 0.044714998453855515 >= 1e-6) return false; // Check Mul[A=sqrt(2/pie)] - float sqrt_2_pie = extractConstant(net, matchedNodesIds[4], 0); + float sqrt_2_pie = extractConstant(net, matchedNodesIds[mul3], 0).at(0); if (sqrt_2_pie - 0.7978845834732056 >= 1e-6) return false; // Check Add[A=1] - float add_const = extractConstant(net, matchedNodesIds[6], 0); + float add_const = extractConstant(net, matchedNodesIds[add1], 0).at(0); if (add_const - 1.f >= 1e-6) return false; // Check Mul[A=0.5] - float mul_const = extractConstant(net, matchedNodesIds[8], 0); + float mul_const = extractConstant(net, matchedNodesIds[mul5], 0).at(0); if (mul_const - 0.5f >= 1e-6) return false; @@ -291,6 +274,9 @@ public: } return false; } + +private: + int mul2, mul3, add1, mul5; }; /* Fusion for LayerNormalization. @@ -313,43 +299,22 @@ public: LayerNormSubGraph() : axis(-1), epsilon(1e-5) { int input = addNodeToMatch(""); - int mean = addNodeToMatch("ReduceMean", input); + mean = addNodeToMatch("ReduceMean", input); int sub = addNodeToMatch("Sub", input, mean); - int pow = addNodeToMatch("Pow", sub, addNodeToMatch("")); - int mean1 = addNodeToMatch("ReduceMean", pow); - int add = addNodeToMatch("Add", mean1, addNodeToMatch("")); + pow = addNodeToMatch("Pow", sub, addNodeToMatch("")); + mean1 = addNodeToMatch("ReduceMean", pow); + add = addNodeToMatch("Add", mean1, addNodeToMatch("")); int sqrt = addNodeToMatch("Sqrt", add); int div = addNodeToMatch("Div", sub, sqrt); - int mul = addNodeToMatch("Mul", div, addNodeToMatch("")); - addNodeToMatch("Add", mul, addNodeToMatch("")); + mul = addNodeToMatch("Mul", div, addNodeToMatch("")); + bias = addNodeToMatch("Add", mul, addNodeToMatch("")); setFusedNode("LayerNormalization", input); } - static float extractConstant(const Ptr& net, int node_id, int input_id) - { - auto onnx_net = net.dynamicCast(); - int initializer_id = onnx_net->getInputInitializerId(node_id, input_id); - if (initializer_id != -1) // initializer - { - Mat const_mat = onnx_net->getMatFromInitializer(initializer_id); - return *const_mat.ptr(); - } - else - { - const Ptr node = net->getNode(node_id); - int constant_id = getInputNodeId(net, node, input_id); - Ptr constant_ptr = net->getNode(constant_id); - opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast()->node; - opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t(); - Mat constant_mat = getMatFromTensor(constant_proto); - return *constant_mat.ptr(); - } - } - static float extractAxis(const Ptr& net, int node_id) { Ptr mean_ptr = net->getNode(node_id); @@ -381,25 +346,24 @@ public: } virtual bool match(const Ptr& net, int nodeId, - std::vector& matchedNodesIds, - std::vector& targetNodesIds) CV_OVERRIDE + std::vector& matchedNodesIds) CV_OVERRIDE { - if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) + if (Subgraph::match(net, nodeId, matchedNodesIds)) { - float pow_exp = extractConstant(net, matchedNodesIds[2], 1); + float pow_exp = extractConstant(net, matchedNodesIds[pow], 1).at(0); if (pow_exp - 2 > 1e-5) // not pow(2) return false; - int axis_mean1 = extractAxis(net, matchedNodesIds[0]); - int axis_mean2 = extractAxis(net, matchedNodesIds[3]); + int axis_mean1 = extractAxis(net, matchedNodesIds[mean]); + int axis_mean2 = extractAxis(net, matchedNodesIds[mean1]); if (axis_mean1 != axis_mean2) return false; axis = axis_mean1; - epsilon = extractConstant(net, matchedNodesIds[4], 1); + epsilon = extractConstant(net, matchedNodesIds[add], 1).at(0); - weight_name = getInputName(net, matchedNodesIds[7], 1); - bias_name = getInputName(net, matchedNodesIds[8], 1); + weight_name = getInputName(net, matchedNodesIds[mul], 1); + bias_name = getInputName(net, matchedNodesIds[bias], 1); return true; } @@ -429,6 +393,7 @@ protected: float epsilon; std::string weight_name; std::string bias_name; + int pow, mean, mean1, add, mul, bias; }; class SoftMaxSubgraphBase : public Subgraph @@ -437,10 +402,9 @@ public: SoftMaxSubgraphBase() : axis(1), id(-1) {} virtual bool match(const Ptr& net, int nodeId, - std::vector& matchedNodesIds, - std::vector& targetNodesIds) CV_OVERRIDE + std::vector& matchedNodesIds) CV_OVERRIDE { - if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) + if (Subgraph::match(net, nodeId, matchedNodesIds)) { CV_Assert(id >= 0 && id < matchedNodesIds.size()); Ptr sum = net->getNode(matchedNodesIds[id]); @@ -485,7 +449,7 @@ public: int inpExp = addNodeToMatch("Exp", input); int sum = addNodeToMatch("ReduceSum", inpExp); - id = 1; + id = sum; addNodeToMatch("Div", inpExp, sum); setFusedNode("Softmax", input); @@ -498,7 +462,7 @@ public: int input = addNodeToMatch(""); int reducemax = addNodeToMatch("ReduceMax", input); - id = 0; + id = reducemax; int sub = addNodeToMatch("Sub", input, reducemax); int exp = addNodeToMatch("Exp", sub); @@ -516,7 +480,7 @@ public: int input = addNodeToMatch(""); int reducemax = addNodeToMatch("ReduceMax", input); - id = 0; + id = reducemax; int sub_1 = addNodeToMatch("Sub", input, reducemax); int exp = addNodeToMatch("Exp", sub_1); @@ -533,18 +497,17 @@ public: HardSwishSubgraph() { int input = addNodeToMatch(""); - int hardSigmoid = addNodeToMatch("HardSigmoid", input); - addNodeToMatch("Mul", input, hardSigmoid); + hardSigmoidId = addNodeToMatch("HardSigmoid", input); + addNodeToMatch("Mul", input, hardSigmoidId); setFusedNode("HardSwish", input); } virtual bool match(const Ptr& net, int nodeId, - std::vector& matchedNodesIds, - std::vector& targetNodesIds) CV_OVERRIDE + std::vector& matchedNodesIds) CV_OVERRIDE { - if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) + if (Subgraph::match(net, nodeId, matchedNodesIds)) { - Ptr hardSigmoid = net->getNode(matchedNodesIds[0]); + Ptr hardSigmoid = net->getNode(matchedNodesIds[hardSigmoidId]); opencv_onnx::NodeProto* node = hardSigmoid.dynamicCast()->node; uint8_t matched = 0; @@ -561,6 +524,9 @@ public: } return false; } + +private: + int hardSigmoidId; }; class CeluSubgraph : public Subgraph @@ -569,9 +535,9 @@ public: CeluSubgraph() : alpha(1.f) { int input = addNodeToMatch(""); - int div = addNodeToMatch("Div", input, addNodeToMatch("")); - int elu = addNodeToMatch("Elu", div); - addNodeToMatch("Mul", addNodeToMatch(""), elu); + div = addNodeToMatch("Div", input, addNodeToMatch("")); + elu = addNodeToMatch("Elu", div); + mul = addNodeToMatch("Mul", addNodeToMatch(""), elu); setFusedNode("Celu", input); } @@ -587,16 +553,15 @@ public: } virtual bool match(const Ptr& net, int nodeId, - std::vector& matchedNodesIds, - std::vector& targetNodesIds) CV_OVERRIDE + std::vector& matchedNodesIds) CV_OVERRIDE { - if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) + if (Subgraph::match(net, nodeId, matchedNodesIds)) { - float alpha_div = extractAlpha(net, matchedNodesIds[0], 1); - float alpha_mul = extractAlpha(net, matchedNodesIds[2], 0); + float alpha_div = extractAlpha(net, matchedNodesIds[div], 1); + float alpha_mul = extractAlpha(net, matchedNodesIds[mul], 0); float alpha_elu = 1.f; - Ptr elu_ptr = net->getNode(matchedNodesIds[1]); + Ptr elu_ptr = net->getNode(matchedNodesIds[elu]); opencv_onnx::NodeProto* elu_node = elu_ptr.dynamicCast()->node; for (int i = 0; i < elu_node->attribute_size(); i++) @@ -625,18 +590,18 @@ public: protected: float alpha; + int div, mul, elu; }; class NormalizeSubgraphBase : public Subgraph { public: - NormalizeSubgraphBase(int _normNodeOrder = 0) : axis(1), normNodeOrder(_normNodeOrder) {} + NormalizeSubgraphBase(int _normNodeOrder = 1) : axis(1), normNodeOrder(_normNodeOrder) {} virtual bool match(const Ptr& net, int nodeId, - std::vector& matchedNodesIds, - std::vector& targetNodesIds) CV_OVERRIDE + std::vector& matchedNodesIds) CV_OVERRIDE { - if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) + if (Subgraph::match(net, nodeId, matchedNodesIds)) { Ptr norm = net->getNode(matchedNodesIds[normNodeOrder]); opencv_onnx::NodeProto* node = norm.dynamicCast()->node; @@ -725,7 +690,7 @@ public: class NormalizeSubgraph3 : public NormalizeSubgraphBase { public: - NormalizeSubgraph3() : NormalizeSubgraphBase(1) + NormalizeSubgraph3() : NormalizeSubgraphBase(3) { int input = addNodeToMatch(""); int power = addNodeToMatch("Constant"); @@ -743,7 +708,7 @@ public: class NormalizeSubgraph4 : public NormalizeSubgraphBase { public: - NormalizeSubgraph4() : NormalizeSubgraphBase(1) + NormalizeSubgraph4() : NormalizeSubgraphBase(2) { int input = addNodeToMatch(""); int mul = addNodeToMatch("Mul", input, input); @@ -760,7 +725,7 @@ public: class NormalizeSubgraph5 : public NormalizeSubgraphBase { public: - NormalizeSubgraph5() : NormalizeSubgraphBase(1) + NormalizeSubgraph5() : NormalizeSubgraphBase(2) { int input = addNodeToMatch(""); int mul = addNodeToMatch("Mul", input, input); @@ -781,25 +746,24 @@ public: { int input = addNodeToMatch(""); int index = addNodeToMatch("Constant"); - int gather = addNodeToMatch("Gather", input, index); - addNodeToMatch("Cast", gather); + gather = addNodeToMatch("Gather", input, index); + cast = addNodeToMatch("Cast", gather); setFusedNode("Gather", input, index); } virtual bool match(const Ptr& net, int nodeId, - std::vector& matchedNodesIds, - std::vector& targetNodesIds) CV_OVERRIDE + std::vector& matchedNodesIds) CV_OVERRIDE { - bool retVal = Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds); + bool retVal = Subgraph::match(net, nodeId, matchedNodesIds); size_t matchedNodesNum = matchedNodesIds.size(); // Now we check if merging can be made for these Gather and Cast nodes if (!retVal || matchedNodesNum < 2) return retVal; else { - int nodeToMatch = matchedNodesIds[matchedNodesNum - 1]; + int nodeToMatch = matchedNodesIds[cast]; const Ptr node = net->getNode(nodeToMatch); if (node->getType() == "Cast") { - int inpNodeId = matchedNodesIds[matchedNodesNum - 2]; + int inpNodeId = matchedNodesIds[gather]; const Ptr inpNode = net->getNode(inpNodeId); if (inpNode->getType() == "Gather") { int numNodes = net->getNumNodes(); @@ -819,6 +783,9 @@ public: } return retVal; } + +private: + int cast, gather; }; /* Constant folding shape for Expand. @@ -838,12 +805,12 @@ public: { int input = addNodeToMatch(""); int values = addNodeToMatch(""); - int init = addNodeToMatch("ConstantOfShape", values); + init = addNodeToMatch("ConstantOfShape", values); int coeff = addNodeToMatch("Constant"); - int mul = addNodeToMatch("Mul", init, coeff); + mul = addNodeToMatch("Mul", init, coeff); int shape = addNodeToMatch("Constant"); - int condition = addNodeToMatch("Equal", shape, mul); - int where = addNodeToMatch("Where", condition, init, addNodeToMatch("Constant")); + condition = addNodeToMatch("Equal", shape, mul); + where = addNodeToMatch("Where", condition, init, addNodeToMatch("Constant")); addNodeToMatch("Expand", input, where); setFusedNode("Expand", input, shape); } @@ -872,53 +839,28 @@ public: return 0; } - static std::vector extractConstant(const Ptr& net, int node_id, int input_id) - { - auto onnx_net = net.dynamicCast(); - int initializer_id = onnx_net->getInputInitializerId(node_id, input_id); - Mat mat_constant; - if (initializer_id != -1) // initializer - { - mat_constant = onnx_net->getMatFromInitializer(initializer_id); - } - else - { - const Ptr node = net->getNode(node_id); - int constant_id = getInputNodeId(net, node, input_id); - Ptr constant_ptr = net->getNode(constant_id); - opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast()->node; - opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t(); - mat_constant = getMatFromTensor(constant_proto); - } - - std::vector retvals{mat_constant.begin(), mat_constant.end()}; - return retvals; - } - virtual bool match(const Ptr& net, int nodeId, - std::vector& matchedNodesIds, - std::vector& targetNodesIds) CV_OVERRIDE { - if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) { + std::vector& matchedNodesIds) CV_OVERRIDE { + if (Subgraph::match(net, nodeId, matchedNodesIds)) { int64_t value_ConstantOfShape; - if (!extractValue(net, matchedNodesIds[0], value_ConstantOfShape)) { + if (!extractValue(net, matchedNodesIds[init], value_ConstantOfShape)) { return false; } - std::vector input_ConstantOfShape = extractConstant(net, matchedNodesIds[0], 0); + std::vector input_ConstantOfShape = extractConstant(net, matchedNodesIds[init], 0); if (input_ConstantOfShape.size() != static_cast(1)) { return false; } - - auto B_Mul = extractConstant(net, matchedNodesIds[1], 1); + std::vector B_Mul = extractConstant(net, matchedNodesIds[mul], 1); if (B_Mul.size() != static_cast(1)) { return false; } - auto A_Equal = extractConstant(net, matchedNodesIds[2], 0); + std::vector A_Equal = extractConstant(net, matchedNodesIds[condition], 0); if (A_Equal.size() != static_cast(input_ConstantOfShape[0])) { return false; } - auto Y_Where = extractConstant(net, matchedNodesIds[3], 2); + std::vector Y_Where = extractConstant(net, matchedNodesIds[where], 2); if (Y_Where.size() != A_Equal.size()) { return false; } @@ -969,6 +911,9 @@ public: protected: std::vector shape; + +private: + int init, mul, condition, where; }; class MishSubgraph : public Subgraph @@ -979,7 +924,7 @@ public: int input = addNodeToMatch(""); int softplus = addNodeToMatch("Softplus", input); int tanh = addNodeToMatch("Tanh", softplus); - addNodeToMatch("Mul", input, tanh); + addNodeToMatch("Mul", tanh, input); setFusedNode("Mish", input); } }; @@ -999,20 +944,6 @@ public: } }; -class SoftplusSubgraph2: public Subgraph -{ -public: - SoftplusSubgraph2() - { - int input = addNodeToMatch(""); - int exp = addNodeToMatch("Exp", input); - int addVal = addNodeToMatch(""); - int add = addNodeToMatch("Add", exp, addVal); - addNodeToMatch("Log", add); - setFusedNode("Softplus", input); - } -}; - class MulCastSubgraph : public Subgraph { public: @@ -1248,7 +1179,6 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net) subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); - subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); diff --git a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp index 5531b28111..8ba1963512 100644 --- a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp +++ b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp @@ -98,6 +98,14 @@ public: net.mutable_node()->DeleteSubrange(idx, 1); } + virtual inline bool isCommutativeOp(const std::string& type) const CV_OVERRIDE + { + return type == "Add" || type == "Sum" || + type == "Mul" || type == "Prod" || + type == "Max" || type == "Maximum" || type == "Minimum" || + type == "Mean" || type == "SquaredDifference"; + } + tensorflow::GraphDef& net; }; @@ -282,24 +290,26 @@ public: { int input = addNodeToMatch(""); int relu = addNodeToMatch("Relu", input); - int maxValue = addNodeToMatch("Const"); + maxValueId = addNodeToMatch("Const"); int clipValue = addNodeToMatch("Const"); - int minimum = addNodeToMatch("Minimum", relu, maxValue); + int minimum = addNodeToMatch("Minimum", relu, maxValueId); addNodeToMatch("Maximum", minimum, clipValue); setFusedNode("Relu6", input); } virtual bool match(const Ptr& net, int nodeId, - std::vector& matchedNodesIds, - std::vector& targetNodesIds) CV_OVERRIDE + std::vector& matchedNodesIds) CV_OVERRIDE { - if (!Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) + if (!Subgraph::match(net, nodeId, matchedNodesIds)) return false; - tensorflow::NodeDef* node = net->getNode(matchedNodesIds.front() + 1).dynamicCast()->node; + tensorflow::NodeDef* node = net->getNode(matchedNodesIds[maxValueId]).dynamicCast()->node; Mat maxValue = getTensorContent(node->attr().at("value").tensor()); return maxValue.type() == CV_32FC1 && maxValue.total() == 1 && maxValue.at(0) == 6; } + +private: + int maxValueId; }; // Keras' reshape stores output shape in separate Const nodes by one value. @@ -328,15 +338,14 @@ public: } virtual bool match(const Ptr& net, int nodeId, - std::vector& matchedNodesIds, - std::vector& targetNodesIds) CV_OVERRIDE + std::vector& matchedNodesIds) CV_OVERRIDE { Ptr node = net->getNode(nodeId); if (node->getNumInputs() == 0) return false; inpName = node->getInputName(0); - return Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds); + return Subgraph::match(net, nodeId, matchedNodesIds); }