Merge pull request #24483 from dkurt:dnn_fusion_commutative_ops

Commutative rules for DNN subgraph fusion #24483

### Pull Request Readiness Checklist

related: https://github.com/opencv/opencv/pull/24463#issuecomment-1783033931

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on code under the GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There are accuracy tests, performance tests and test data in the opencv_extra repository, if applicable
      The patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
pull/24519/head
Dmitry Kurtaev 1 year ago committed by GitHub
parent 41c335e5a5
commit b7ec2ebb55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 68
      modules/dnn/src/graph_simplifier.cpp
  2. 8
      modules/dnn/src/graph_simplifier.hpp
  3. 304
      modules/dnn/src/onnx/onnx_graph_simplifier.cpp
  4. 27
      modules/dnn/src/tensorflow/tf_graph_simplifier.cpp

@ -77,14 +77,14 @@ int Subgraph::getInputNodeId(const Ptr<ImportGraphWrapper>& net,
} }
bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId, bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds, std::vector<int>& matchedNodesIds)
std::vector<int>& targetNodesIds)
{ {
matchedNodesIds.clear(); matchedNodesIds.clear();
targetNodesIds.clear();
std::queue<int> nodesToMatch; std::queue<int> nodesToMatch;
std::queue<int> targetNodes; std::queue<int> targetNodes;
std::vector<std::pair<int, int> > matchings;
matchings.reserve(nodes.size());
nodesToMatch.push(nodeId); nodesToMatch.push(nodeId);
targetNodes.push(nodes.size() - 1); targetNodes.push(nodes.size() - 1);
while (!nodesToMatch.empty()) while (!nodesToMatch.empty())
@ -94,51 +94,63 @@ bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
nodesToMatch.pop(); nodesToMatch.pop();
targetNodes.pop(); targetNodes.pop();
if (std::find(matchedNodesIds.begin(), matchedNodesIds.end(), nodeToMatch) != if (std::find_if(matchings.begin(), matchings.end(), [&](const std::pair<int, int>& match){ return match.first == targetNodeId; }) !=
matchedNodesIds.end()) matchings.end())
continue; continue;
// Empty placeholder matches with any input type
if (nodes[targetNodeId].empty()) {
matchings.push_back({targetNodeId, nodeToMatch});
continue;
}
const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch); const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch);
if (node->getType() != nodes[targetNodeId]) if (node->getType() != nodes[targetNodeId])
return false; continue;
std::vector<int>& inputNodes = inputs[targetNodeId]; std::vector<int>& inputNodes = inputs[targetNodeId];
if (inputNodes.size() != node->getNumInputs()) if (inputNodes.size() != node->getNumInputs())
return false; continue;
bool isCommutative = net->isCommutativeOp(node->getType());
for (int j = 0; j < inputNodes.size(); ++j) for (int j = 0; j < inputNodes.size(); ++j)
{ {
if (nodes[inputNodes[j]].empty() || node->getInputName(j).empty()) // Unknown input node type. // Sometimes, ONNX may have input but it's empty (see Clip layer from reduceL2_subgraph2_2 testcase)
if (node->getInputName(j).empty())
continue; continue;
nodeId = getInputNodeId(net, node, j); nodeId = getInputNodeId(net, node, j);
const Ptr<ImportNodeWrapper> inpNode = net->getNode(nodeId); const Ptr<ImportNodeWrapper> inpNode = net->getNode(nodeId);
if (inpNode->getType() != "Const" && inpNode->getType() != "Constant") if (isCommutative)
{
for (int i = 0; i < inputNodes.size(); ++i)
{
nodesToMatch.push(nodeId);
targetNodes.push(inputNodes[i]);
}
}
else
{ {
nodesToMatch.push(nodeId); nodesToMatch.push(nodeId);
targetNodes.push(inputNodes[j]); targetNodes.push(inputNodes[j]);
} }
else if (nodes[inputNodes[j]] != "Const" && nodes[inputNodes[j]] != "Constant")
return false;
} }
matchedNodesIds.push_back(nodeToMatch); matchings.push_back({targetNodeId, nodeToMatch});
targetNodesIds.push_back(targetNodeId);
} }
if (matchings.size() != nodes.size())
return false;
const int n = matchedNodesIds.size(); // Sort matched by pattern nodes order.
std::vector<std::pair<int, int> > elements(n); std::sort(matchings.begin(), matchings.end());
for (int i = 0; i < n; ++i) matchedNodesIds.resize(matchings.size());
elements[i] = std::make_pair(matchedNodesIds[i], targetNodesIds[i]); for (int i = 0; i < matchings.size(); ++i)
std::sort(elements.begin(), elements.end());
for (int i = 0; i < n; ++i)
{ {
matchedNodesIds[i] = elements[i].first; matchedNodesIds[i] = matchings[i].second;
targetNodesIds[i] = elements[i].second;
} }
return true; return true;
} }
void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds, void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds)
const std::vector<int>& targetNodesIds)
{ {
// Extract names of input nodes. // Extract names of input nodes.
std::vector<std::string> inputsNames(fusedNodeInputs.size()); std::vector<std::string> inputsNames(fusedNodeInputs.size());
@ -149,9 +161,9 @@ void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int
for (int j = 0; j < matchedNodesIds.size() && inpName.empty(); ++j) for (int j = 0; j < matchedNodesIds.size() && inpName.empty(); ++j)
{ {
Ptr<ImportNodeWrapper> node = net->getNode(matchedNodesIds[j]); Ptr<ImportNodeWrapper> node = net->getNode(matchedNodesIds[j]);
std::vector<int>& inpIndices = inputs[targetNodesIds[j]]; std::vector<int>& inpIndices = inputs[j];
CV_Assert(node->getNumInputs() == inpIndices.size()); CV_Assert(inpIndices.empty() || node->getNumInputs() == inpIndices.size());
for (int k = 0; k < inpIndices.size(); ++k) for (int k = 0; k < inpIndices.size(); ++k)
{ {
if (inpIndices[k] == fusedNodeInputs[i]) if (inpIndices[k] == fusedNodeInputs[i])
@ -187,15 +199,15 @@ void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net,
const std::vector<Ptr<Subgraph> >& patterns) const std::vector<Ptr<Subgraph> >& patterns)
{ {
int numNodes = net->getNumNodes(); int numNodes = net->getNumNodes();
std::vector<int> matchedNodesIds, targetNodesIds; std::vector<int> matchedNodesIds;
std::vector<int> nodesToRemove; std::vector<int> nodesToRemove;
for (int j = 0; j < patterns.size(); ++j) for (int j = 0; j < patterns.size(); ++j)
{ {
for (int i = 0; i < numNodes; ++i) for (int i = 0; i < numNodes; ++i)
{ {
if (patterns[j]->match(net, i, matchedNodesIds, targetNodesIds)) if (patterns[j]->match(net, i, matchedNodesIds))
{ {
patterns[j]->replace(net, matchedNodesIds, targetNodesIds); patterns[j]->replace(net, matchedNodesIds);
// Remove matched nodes except the last one. // Remove matched nodes except the last one.
nodesToRemove.insert(nodesToRemove.end(), matchedNodesIds.begin(), matchedNodesIds.end() - 1); nodesToRemove.insert(nodesToRemove.end(), matchedNodesIds.begin(), matchedNodesIds.end() - 1);
} }

@ -44,6 +44,8 @@ public:
virtual std::string getOutputName(int nodeId, int outId) const = 0; virtual std::string getOutputName(int nodeId, int outId) const = 0;
virtual void removeNode(int idx) = 0; virtual void removeNode(int idx) = 0;
virtual bool isCommutativeOp(const std::string& type) const = 0;
}; };
class Subgraph // Interface to match and replace subgraphs. class Subgraph // Interface to match and replace subgraphs.
@ -75,12 +77,10 @@ public:
// Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused. // Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
// Const nodes are skipped during matching. Returns true if nodes are matched and can be fused. // Const nodes are skipped during matching. Returns true if nodes are matched and can be fused.
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds, std::vector<int>& matchedNodesIds);
std::vector<int>& targetNodesIds);
// Fuse matched subgraph. // Fuse matched subgraph.
void replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds, void replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds);
const std::vector<int>& targetNodesIds);
virtual void finalize(const Ptr<ImportGraphWrapper>& net, virtual void finalize(const Ptr<ImportGraphWrapper>& net,
const Ptr<ImportNodeWrapper>& fusedNode, const Ptr<ImportNodeWrapper>& fusedNode,

@ -125,8 +125,13 @@ public:
virtual void removeNode(int idx) CV_OVERRIDE virtual void removeNode(int idx) CV_OVERRIDE
{ {
CV_Assert(idx >= numInputs + numInitializers); if (idx >= numInputs + numInitializers)
net.mutable_node()->DeleteSubrange(idx - numInputs - numInitializers, 1); net.mutable_node()->DeleteSubrange(idx - numInputs - numInitializers, 1);
}
virtual inline bool isCommutativeOp(const std::string& type) const CV_OVERRIDE
{
return type == "Add" || type == "Mul" || type == "Equal" || type == "Max";
} }
private: private:
@ -134,6 +139,25 @@ private:
opencv_onnx::GraphProto& net; opencv_onnx::GraphProto& net;
}; };
static Mat extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
{
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
if (initializer_id != -1)
{
return onnx_net->getMatFromInitializer(initializer_id);
}
else
{
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
int constant_id = Subgraph::getInputNodeId(net, node, input_id);
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
return getMatFromTensor(constant_proto);
}
}
/* Fusion for Gelu. /* Fusion for Gelu.
Graph before fusion: Graph before fusion:
@ -151,54 +175,32 @@ public:
GeluSubGraph() GeluSubGraph()
{ {
int input = addNodeToMatch(""); int input = addNodeToMatch("");
int div = addNodeToMatch("Div", input, addNodeToMatch("") /* B=sqrt(2) */ ); div = addNodeToMatch("Div", input, addNodeToMatch("") /* B=sqrt(2) */ );
int erf = addNodeToMatch("Erf", div); int erf = addNodeToMatch("Erf", div);
int add = addNodeToMatch("Add", erf, addNodeToMatch("") /* B=1 */ ); add = addNodeToMatch("Add", erf, addNodeToMatch("") /* B=1 */ );
int mul = addNodeToMatch("Mul", input, add); int mul = addNodeToMatch("Mul", input, add);
addNodeToMatch("Mul", mul, addNodeToMatch("") /* B=0.5 */) ; mul2 = addNodeToMatch("Mul", mul, addNodeToMatch("") /* B=0.5 */) ;
setFusedNode("Gelu", input); setFusedNode("Gelu", input);
} }
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
{
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
if (initializer_id != -1)
{
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
return *const_mat.ptr<float>();
}
else
{
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
int constant_id = getInputNodeId(net, node, input_id);
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
Mat constant_mat = getMatFromTensor(constant_proto);
return *constant_mat.ptr<float>();
}
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds, std::vector<int>& matchedNodesIds) CV_OVERRIDE
std::vector<int>& targetNodesIds) CV_OVERRIDE
{ {
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) if (Subgraph::match(net, nodeId, matchedNodesIds))
{ {
// Check Div[B=sqrt(2)] // Check Div[B=sqrt(2)]
float divisor = extractConstant(net, matchedNodesIds[0], 1); float divisor = extractConstant(net, matchedNodesIds[div], 1).at<float>(0);
if (std::fabs(divisor - M_SQRT2) >= std::numeric_limits<float>::epsilon()) if (std::fabs(divisor - M_SQRT2) >= std::numeric_limits<float>::epsilon())
return false; return false;
// Check Add[B=1] // Check Add[B=1]
float add_const = extractConstant(net, matchedNodesIds[2], 1); float add_const = extractConstant(net, matchedNodesIds[add], 1).at<float>(0);
if (std::fabs(add_const - 1.f) >= std::numeric_limits<float>::epsilon()) if (std::fabs(add_const - 1.f) >= std::numeric_limits<float>::epsilon())
return false; return false;
// Check Mul[B=0.5] // Check Mul[B=0.5]
float mul_const = extractConstant(net, matchedNodesIds[4], 1); float mul_const = extractConstant(net, matchedNodesIds[mul2], 1).at<float>(0);
if (std::fabs(mul_const - 0.5f) >= std::numeric_limits<float>::epsilon()) if (std::fabs(mul_const - 0.5f) >= std::numeric_limits<float>::epsilon())
return false; return false;
@ -206,6 +208,9 @@ public:
} }
return false; return false;
} }
private:
int div, add, mul2;
}; };
/* Fusion for GeluApproximation. /* Fusion for GeluApproximation.
@ -229,61 +234,39 @@ public:
int input = addNodeToMatch(""); int input = addNodeToMatch("");
int mul0 = addNodeToMatch("Mul", input, input); int mul0 = addNodeToMatch("Mul", input, input);
int mul1 = addNodeToMatch("Mul", input, mul0); int mul1 = addNodeToMatch("Mul", input, mul0);
int mul2 = addNodeToMatch("Mul", addNodeToMatch("") /* A=0.044714998453855515 */, mul1); mul2 = addNodeToMatch("Mul", addNodeToMatch("") /* A=0.044714998453855515 */, mul1);
int add0 = addNodeToMatch("Add", input, mul2); int add0 = addNodeToMatch("Add", input, mul2);
int mul3 = addNodeToMatch("Mul", addNodeToMatch("") /* A=sqrt(2/pie) */, add0); mul3 = addNodeToMatch("Mul", addNodeToMatch("") /* A=sqrt(2/pie) */, add0);
int tanh = addNodeToMatch("Tanh", mul3); int tanh = addNodeToMatch("Tanh", mul3);
int add1 = addNodeToMatch("Add", addNodeToMatch("") /* A=1 */, tanh); add1 = addNodeToMatch("Add", addNodeToMatch("") /* A=1 */, tanh);
int mul4 = addNodeToMatch("Mul", input, add1); int mul4 = addNodeToMatch("Mul", input, add1);
addNodeToMatch("Mul", addNodeToMatch("") /* A=0.5 */, mul4); mul5 = addNodeToMatch("Mul", addNodeToMatch("") /* A=0.5 */, mul4);
setFusedNode("GeluApproximation", input); setFusedNode("GeluApproximation", input);
} }
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
{
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
if (initializer_id != -1)
{
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
return *const_mat.ptr<float>();
}
else
{
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
int constant_id = getInputNodeId(net, node, input_id);
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
Mat constant_mat = getMatFromTensor(constant_proto);
return *constant_mat.ptr<float>();
}
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds, std::vector<int>& matchedNodesIds) CV_OVERRIDE
std::vector<int>& targetNodesIds) CV_OVERRIDE
{ {
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) if (Subgraph::match(net, nodeId, matchedNodesIds))
{ {
// Check Mul[A=0.044714998453855515] // Check Mul[A=0.044714998453855515]
float coef = extractConstant(net, matchedNodesIds[2], 0); float coef = extractConstant(net, matchedNodesIds[mul2], 0).at<float>(0);
if (coef - 0.044714998453855515 >= 1e-6) if (coef - 0.044714998453855515 >= 1e-6)
return false; return false;
// Check Mul[A=sqrt(2/pie)] // Check Mul[A=sqrt(2/pie)]
float sqrt_2_pie = extractConstant(net, matchedNodesIds[4], 0); float sqrt_2_pie = extractConstant(net, matchedNodesIds[mul3], 0).at<float>(0);
if (sqrt_2_pie - 0.7978845834732056 >= 1e-6) if (sqrt_2_pie - 0.7978845834732056 >= 1e-6)
return false; return false;
// Check Add[A=1] // Check Add[A=1]
float add_const = extractConstant(net, matchedNodesIds[6], 0); float add_const = extractConstant(net, matchedNodesIds[add1], 0).at<float>(0);
if (add_const - 1.f >= 1e-6) if (add_const - 1.f >= 1e-6)
return false; return false;
// Check Mul[A=0.5] // Check Mul[A=0.5]
float mul_const = extractConstant(net, matchedNodesIds[8], 0); float mul_const = extractConstant(net, matchedNodesIds[mul5], 0).at<float>(0);
if (mul_const - 0.5f >= 1e-6) if (mul_const - 0.5f >= 1e-6)
return false; return false;
@ -291,6 +274,9 @@ public:
} }
return false; return false;
} }
private:
int mul2, mul3, add1, mul5;
}; };
/* Fusion for LayerNormalization. /* Fusion for LayerNormalization.
@ -313,43 +299,22 @@ public:
LayerNormSubGraph() : axis(-1), epsilon(1e-5) LayerNormSubGraph() : axis(-1), epsilon(1e-5)
{ {
int input = addNodeToMatch(""); int input = addNodeToMatch("");
int mean = addNodeToMatch("ReduceMean", input); mean = addNodeToMatch("ReduceMean", input);
int sub = addNodeToMatch("Sub", input, mean); int sub = addNodeToMatch("Sub", input, mean);
int pow = addNodeToMatch("Pow", sub, addNodeToMatch("")); pow = addNodeToMatch("Pow", sub, addNodeToMatch(""));
int mean1 = addNodeToMatch("ReduceMean", pow); mean1 = addNodeToMatch("ReduceMean", pow);
int add = addNodeToMatch("Add", mean1, addNodeToMatch("")); add = addNodeToMatch("Add", mean1, addNodeToMatch(""));
int sqrt = addNodeToMatch("Sqrt", add); int sqrt = addNodeToMatch("Sqrt", add);
int div = addNodeToMatch("Div", sub, sqrt); int div = addNodeToMatch("Div", sub, sqrt);
int mul = addNodeToMatch("Mul", div, addNodeToMatch("")); mul = addNodeToMatch("Mul", div, addNodeToMatch(""));
addNodeToMatch("Add", mul, addNodeToMatch("")); bias = addNodeToMatch("Add", mul, addNodeToMatch(""));
setFusedNode("LayerNormalization", input); setFusedNode("LayerNormalization", input);
} }
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
{
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
if (initializer_id != -1) // initializer
{
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
return *const_mat.ptr<float>();
}
else
{
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
int constant_id = getInputNodeId(net, node, input_id);
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
Mat constant_mat = getMatFromTensor(constant_proto);
return *constant_mat.ptr<float>();
}
}
static float extractAxis(const Ptr<ImportGraphWrapper>& net, int node_id) static float extractAxis(const Ptr<ImportGraphWrapper>& net, int node_id)
{ {
Ptr<ImportNodeWrapper> mean_ptr = net->getNode(node_id); Ptr<ImportNodeWrapper> mean_ptr = net->getNode(node_id);
@ -381,25 +346,24 @@ public:
} }
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds, std::vector<int>& matchedNodesIds) CV_OVERRIDE
std::vector<int>& targetNodesIds) CV_OVERRIDE
{ {
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) if (Subgraph::match(net, nodeId, matchedNodesIds))
{ {
float pow_exp = extractConstant(net, matchedNodesIds[2], 1); float pow_exp = extractConstant(net, matchedNodesIds[pow], 1).at<float>(0);
if (pow_exp - 2 > 1e-5) // not pow(2) if (pow_exp - 2 > 1e-5) // not pow(2)
return false; return false;
int axis_mean1 = extractAxis(net, matchedNodesIds[0]); int axis_mean1 = extractAxis(net, matchedNodesIds[mean]);
int axis_mean2 = extractAxis(net, matchedNodesIds[3]); int axis_mean2 = extractAxis(net, matchedNodesIds[mean1]);
if (axis_mean1 != axis_mean2) if (axis_mean1 != axis_mean2)
return false; return false;
axis = axis_mean1; axis = axis_mean1;
epsilon = extractConstant(net, matchedNodesIds[4], 1); epsilon = extractConstant(net, matchedNodesIds[add], 1).at<float>(0);
weight_name = getInputName(net, matchedNodesIds[7], 1); weight_name = getInputName(net, matchedNodesIds[mul], 1);
bias_name = getInputName(net, matchedNodesIds[8], 1); bias_name = getInputName(net, matchedNodesIds[bias], 1);
return true; return true;
} }
@ -429,6 +393,7 @@ protected:
float epsilon; float epsilon;
std::string weight_name; std::string weight_name;
std::string bias_name; std::string bias_name;
int pow, mean, mean1, add, mul, bias;
}; };
class SoftMaxSubgraphBase : public Subgraph class SoftMaxSubgraphBase : public Subgraph
@ -437,10 +402,9 @@ public:
SoftMaxSubgraphBase() : axis(1), id(-1) {} SoftMaxSubgraphBase() : axis(1), id(-1) {}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds, std::vector<int>& matchedNodesIds) CV_OVERRIDE
std::vector<int>& targetNodesIds) CV_OVERRIDE
{ {
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) if (Subgraph::match(net, nodeId, matchedNodesIds))
{ {
CV_Assert(id >= 0 && id < matchedNodesIds.size()); CV_Assert(id >= 0 && id < matchedNodesIds.size());
Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[id]); Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[id]);
@ -485,7 +449,7 @@ public:
int inpExp = addNodeToMatch("Exp", input); int inpExp = addNodeToMatch("Exp", input);
int sum = addNodeToMatch("ReduceSum", inpExp); int sum = addNodeToMatch("ReduceSum", inpExp);
id = 1; id = sum;
addNodeToMatch("Div", inpExp, sum); addNodeToMatch("Div", inpExp, sum);
setFusedNode("Softmax", input); setFusedNode("Softmax", input);
@ -498,7 +462,7 @@ public:
int input = addNodeToMatch(""); int input = addNodeToMatch("");
int reducemax = addNodeToMatch("ReduceMax", input); int reducemax = addNodeToMatch("ReduceMax", input);
id = 0; id = reducemax;
int sub = addNodeToMatch("Sub", input, reducemax); int sub = addNodeToMatch("Sub", input, reducemax);
int exp = addNodeToMatch("Exp", sub); int exp = addNodeToMatch("Exp", sub);
@ -516,7 +480,7 @@ public:
int input = addNodeToMatch(""); int input = addNodeToMatch("");
int reducemax = addNodeToMatch("ReduceMax", input); int reducemax = addNodeToMatch("ReduceMax", input);
id = 0; id = reducemax;
int sub_1 = addNodeToMatch("Sub", input, reducemax); int sub_1 = addNodeToMatch("Sub", input, reducemax);
int exp = addNodeToMatch("Exp", sub_1); int exp = addNodeToMatch("Exp", sub_1);
@ -533,18 +497,17 @@ public:
HardSwishSubgraph() HardSwishSubgraph()
{ {
int input = addNodeToMatch(""); int input = addNodeToMatch("");
int hardSigmoid = addNodeToMatch("HardSigmoid", input); hardSigmoidId = addNodeToMatch("HardSigmoid", input);
addNodeToMatch("Mul", input, hardSigmoid); addNodeToMatch("Mul", input, hardSigmoidId);
setFusedNode("HardSwish", input); setFusedNode("HardSwish", input);
} }
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds, std::vector<int>& matchedNodesIds) CV_OVERRIDE
std::vector<int>& targetNodesIds) CV_OVERRIDE
{ {
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) if (Subgraph::match(net, nodeId, matchedNodesIds))
{ {
Ptr<ImportNodeWrapper> hardSigmoid = net->getNode(matchedNodesIds[0]); Ptr<ImportNodeWrapper> hardSigmoid = net->getNode(matchedNodesIds[hardSigmoidId]);
opencv_onnx::NodeProto* node = hardSigmoid.dynamicCast<ONNXNodeWrapper>()->node; opencv_onnx::NodeProto* node = hardSigmoid.dynamicCast<ONNXNodeWrapper>()->node;
uint8_t matched = 0; uint8_t matched = 0;
@ -561,6 +524,9 @@ public:
} }
return false; return false;
} }
private:
int hardSigmoidId;
}; };
class CeluSubgraph : public Subgraph class CeluSubgraph : public Subgraph
@ -569,9 +535,9 @@ public:
CeluSubgraph() : alpha(1.f) CeluSubgraph() : alpha(1.f)
{ {
int input = addNodeToMatch(""); int input = addNodeToMatch("");
int div = addNodeToMatch("Div", input, addNodeToMatch("")); div = addNodeToMatch("Div", input, addNodeToMatch(""));
int elu = addNodeToMatch("Elu", div); elu = addNodeToMatch("Elu", div);
addNodeToMatch("Mul", addNodeToMatch(""), elu); mul = addNodeToMatch("Mul", addNodeToMatch(""), elu);
setFusedNode("Celu", input); setFusedNode("Celu", input);
} }
@ -587,16 +553,15 @@ public:
} }
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds, std::vector<int>& matchedNodesIds) CV_OVERRIDE
std::vector<int>& targetNodesIds) CV_OVERRIDE
{ {
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) if (Subgraph::match(net, nodeId, matchedNodesIds))
{ {
float alpha_div = extractAlpha(net, matchedNodesIds[0], 1); float alpha_div = extractAlpha(net, matchedNodesIds[div], 1);
float alpha_mul = extractAlpha(net, matchedNodesIds[2], 0); float alpha_mul = extractAlpha(net, matchedNodesIds[mul], 0);
float alpha_elu = 1.f; float alpha_elu = 1.f;
Ptr<ImportNodeWrapper> elu_ptr = net->getNode(matchedNodesIds[1]); Ptr<ImportNodeWrapper> elu_ptr = net->getNode(matchedNodesIds[elu]);
opencv_onnx::NodeProto* elu_node = elu_ptr.dynamicCast<ONNXNodeWrapper>()->node; opencv_onnx::NodeProto* elu_node = elu_ptr.dynamicCast<ONNXNodeWrapper>()->node;
for (int i = 0; i < elu_node->attribute_size(); i++) for (int i = 0; i < elu_node->attribute_size(); i++)
@ -625,18 +590,18 @@ public:
protected: protected:
float alpha; float alpha;
int div, mul, elu;
}; };
class NormalizeSubgraphBase : public Subgraph class NormalizeSubgraphBase : public Subgraph
{ {
public: public:
NormalizeSubgraphBase(int _normNodeOrder = 0) : axis(1), normNodeOrder(_normNodeOrder) {} NormalizeSubgraphBase(int _normNodeOrder = 1) : axis(1), normNodeOrder(_normNodeOrder) {}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds, std::vector<int>& matchedNodesIds) CV_OVERRIDE
std::vector<int>& targetNodesIds) CV_OVERRIDE
{ {
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) if (Subgraph::match(net, nodeId, matchedNodesIds))
{ {
Ptr<ImportNodeWrapper> norm = net->getNode(matchedNodesIds[normNodeOrder]); Ptr<ImportNodeWrapper> norm = net->getNode(matchedNodesIds[normNodeOrder]);
opencv_onnx::NodeProto* node = norm.dynamicCast<ONNXNodeWrapper>()->node; opencv_onnx::NodeProto* node = norm.dynamicCast<ONNXNodeWrapper>()->node;
@ -725,7 +690,7 @@ public:
class NormalizeSubgraph3 : public NormalizeSubgraphBase class NormalizeSubgraph3 : public NormalizeSubgraphBase
{ {
public: public:
NormalizeSubgraph3() : NormalizeSubgraphBase(1) NormalizeSubgraph3() : NormalizeSubgraphBase(3)
{ {
int input = addNodeToMatch(""); int input = addNodeToMatch("");
int power = addNodeToMatch("Constant"); int power = addNodeToMatch("Constant");
@ -743,7 +708,7 @@ public:
class NormalizeSubgraph4 : public NormalizeSubgraphBase class NormalizeSubgraph4 : public NormalizeSubgraphBase
{ {
public: public:
NormalizeSubgraph4() : NormalizeSubgraphBase(1) NormalizeSubgraph4() : NormalizeSubgraphBase(2)
{ {
int input = addNodeToMatch(""); int input = addNodeToMatch("");
int mul = addNodeToMatch("Mul", input, input); int mul = addNodeToMatch("Mul", input, input);
@ -760,7 +725,7 @@ public:
class NormalizeSubgraph5 : public NormalizeSubgraphBase class NormalizeSubgraph5 : public NormalizeSubgraphBase
{ {
public: public:
NormalizeSubgraph5() : NormalizeSubgraphBase(1) NormalizeSubgraph5() : NormalizeSubgraphBase(2)
{ {
int input = addNodeToMatch(""); int input = addNodeToMatch("");
int mul = addNodeToMatch("Mul", input, input); int mul = addNodeToMatch("Mul", input, input);
@ -781,25 +746,24 @@ public:
{ {
int input = addNodeToMatch(""); int input = addNodeToMatch("");
int index = addNodeToMatch("Constant"); int index = addNodeToMatch("Constant");
int gather = addNodeToMatch("Gather", input, index); gather = addNodeToMatch("Gather", input, index);
addNodeToMatch("Cast", gather); cast = addNodeToMatch("Cast", gather);
setFusedNode("Gather", input, index); setFusedNode("Gather", input, index);
} }
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds, std::vector<int>& matchedNodesIds) CV_OVERRIDE
std::vector<int>& targetNodesIds) CV_OVERRIDE
{ {
bool retVal = Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds); bool retVal = Subgraph::match(net, nodeId, matchedNodesIds);
size_t matchedNodesNum = matchedNodesIds.size(); size_t matchedNodesNum = matchedNodesIds.size();
// Now we check if merging can be made for these Gather and Cast nodes // Now we check if merging can be made for these Gather and Cast nodes
if (!retVal || matchedNodesNum < 2) if (!retVal || matchedNodesNum < 2)
return retVal; return retVal;
else { else {
int nodeToMatch = matchedNodesIds[matchedNodesNum - 1]; int nodeToMatch = matchedNodesIds[cast];
const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch); const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch);
if (node->getType() == "Cast") { if (node->getType() == "Cast") {
int inpNodeId = matchedNodesIds[matchedNodesNum - 2]; int inpNodeId = matchedNodesIds[gather];
const Ptr<ImportNodeWrapper> inpNode = net->getNode(inpNodeId); const Ptr<ImportNodeWrapper> inpNode = net->getNode(inpNodeId);
if (inpNode->getType() == "Gather") { if (inpNode->getType() == "Gather") {
int numNodes = net->getNumNodes(); int numNodes = net->getNumNodes();
@ -819,6 +783,9 @@ public:
} }
return retVal; return retVal;
} }
private:
int cast, gather;
}; };
/* Constant folding shape for Expand. /* Constant folding shape for Expand.
@ -838,12 +805,12 @@ public:
{ {
int input = addNodeToMatch(""); int input = addNodeToMatch("");
int values = addNodeToMatch(""); int values = addNodeToMatch("");
int init = addNodeToMatch("ConstantOfShape", values); init = addNodeToMatch("ConstantOfShape", values);
int coeff = addNodeToMatch("Constant"); int coeff = addNodeToMatch("Constant");
int mul = addNodeToMatch("Mul", init, coeff); mul = addNodeToMatch("Mul", init, coeff);
int shape = addNodeToMatch("Constant"); int shape = addNodeToMatch("Constant");
int condition = addNodeToMatch("Equal", shape, mul); condition = addNodeToMatch("Equal", shape, mul);
int where = addNodeToMatch("Where", condition, init, addNodeToMatch("Constant")); where = addNodeToMatch("Where", condition, init, addNodeToMatch("Constant"));
addNodeToMatch("Expand", input, where); addNodeToMatch("Expand", input, where);
setFusedNode("Expand", input, shape); setFusedNode("Expand", input, shape);
} }
@ -872,53 +839,28 @@ public:
return 0; return 0;
} }
static std::vector<int64_t> extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
{
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
Mat mat_constant;
if (initializer_id != -1) // initializer
{
mat_constant = onnx_net->getMatFromInitializer(initializer_id);
}
else
{
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
int constant_id = getInputNodeId(net, node, input_id);
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
mat_constant = getMatFromTensor(constant_proto);
}
std::vector<int64_t> retvals{mat_constant.begin<int>(), mat_constant.end<int>()};
return retvals;
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds, std::vector<int>& matchedNodesIds) CV_OVERRIDE {
std::vector<int>& targetNodesIds) CV_OVERRIDE { if (Subgraph::match(net, nodeId, matchedNodesIds)) {
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) {
int64_t value_ConstantOfShape; int64_t value_ConstantOfShape;
if (!extractValue(net, matchedNodesIds[0], value_ConstantOfShape)) { if (!extractValue(net, matchedNodesIds[init], value_ConstantOfShape)) {
return false; return false;
} }
std::vector<int64_t> input_ConstantOfShape = extractConstant(net, matchedNodesIds[0], 0); std::vector<int> input_ConstantOfShape = extractConstant(net, matchedNodesIds[init], 0);
if (input_ConstantOfShape.size() != static_cast<size_t>(1)) { if (input_ConstantOfShape.size() != static_cast<size_t>(1)) {
return false; return false;
} }
std::vector<int> B_Mul = extractConstant(net, matchedNodesIds[mul], 1);
auto B_Mul = extractConstant(net, matchedNodesIds[1], 1);
if (B_Mul.size() != static_cast<size_t>(1)) { if (B_Mul.size() != static_cast<size_t>(1)) {
return false; return false;
} }
auto A_Equal = extractConstant(net, matchedNodesIds[2], 0); std::vector<int> A_Equal = extractConstant(net, matchedNodesIds[condition], 0);
if (A_Equal.size() != static_cast<size_t>(input_ConstantOfShape[0])) { if (A_Equal.size() != static_cast<size_t>(input_ConstantOfShape[0])) {
return false; return false;
} }
auto Y_Where = extractConstant(net, matchedNodesIds[3], 2); std::vector<int> Y_Where = extractConstant(net, matchedNodesIds[where], 2);
if (Y_Where.size() != A_Equal.size()) { if (Y_Where.size() != A_Equal.size()) {
return false; return false;
} }
@ -969,6 +911,9 @@ public:
protected: protected:
std::vector<int64_t> shape; std::vector<int64_t> shape;
private:
int init, mul, condition, where;
}; };
class MishSubgraph : public Subgraph class MishSubgraph : public Subgraph
@ -979,7 +924,7 @@ public:
int input = addNodeToMatch(""); int input = addNodeToMatch("");
int softplus = addNodeToMatch("Softplus", input); int softplus = addNodeToMatch("Softplus", input);
int tanh = addNodeToMatch("Tanh", softplus); int tanh = addNodeToMatch("Tanh", softplus);
addNodeToMatch("Mul", input, tanh); addNodeToMatch("Mul", tanh, input);
setFusedNode("Mish", input); setFusedNode("Mish", input);
} }
}; };
@ -999,20 +944,6 @@ public:
} }
}; };
// Pattern Log(Add(Exp(x), y)) that is replaced by a single Softplus(x) node.
//
// NOTE(review): the second Add operand is registered with an empty op type,
// which presumably matches any node; nothing in this class checks its value -
// TODO confirm that is intended.
class SoftplusSubgraph2: public Subgraph
{
public:
    SoftplusSubgraph2()
    {
        // Nodes are registered one statement at a time so their ids keep the
        // same order: x, Exp, addend, Add, Log.
        const int x = addNodeToMatch("");
        const int expOfX = addNodeToMatch("Exp", x);
        const int addend = addNodeToMatch("");
        const int sum = addNodeToMatch("Add", expOfX, addend);
        addNodeToMatch("Log", sum);

        // The fused node consumes only the original input x.
        setFusedNode("Softplus", x);
    }
};
class MulCastSubgraph : public Subgraph class MulCastSubgraph : public Subgraph
{ {
public: public:
@ -1248,7 +1179,6 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>()); subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
subgraphs.push_back(makePtr<ExpandSubgraph>()); subgraphs.push_back(makePtr<ExpandSubgraph>());
subgraphs.push_back(makePtr<SoftplusSubgraph>()); subgraphs.push_back(makePtr<SoftplusSubgraph>());
subgraphs.push_back(makePtr<SoftplusSubgraph2>());
subgraphs.push_back(makePtr<MishSubgraph>()); subgraphs.push_back(makePtr<MishSubgraph>());
subgraphs.push_back(makePtr<NormalizeSubgraph4>()); subgraphs.push_back(makePtr<NormalizeSubgraph4>());
subgraphs.push_back(makePtr<NormalizeSubgraph5>()); subgraphs.push_back(makePtr<NormalizeSubgraph5>());

@ -98,6 +98,14 @@ public:
net.mutable_node()->DeleteSubrange(idx, 1); net.mutable_node()->DeleteSubrange(idx, 1);
} }
// Returns true when `type` is one of the op names this wrapper treats as
// commutative, and false for every other string. The comparison is exact and
// case-sensitive.
//
// NOTE(review): "Sum", "Prod", "Max" and "Mean" look like TensorFlow reduction
// op names (data, axes) rather than element-wise ones - TODO confirm they are
// meant to be order-insensitive for matching.
virtual inline bool isCommutativeOp(const std::string& type) const CV_OVERRIDE
{
    static const char* const kCommutativeOps[] = {
        "Add", "Sum",
        "Mul", "Prod",
        "Max", "Maximum", "Minimum",
        "Mean", "SquaredDifference"
    };

    for (const char* opName : kCommutativeOps)
    {
        if (type == opName)
            return true;
    }
    return false;
}
tensorflow::GraphDef& net; tensorflow::GraphDef& net;
}; };
@ -282,24 +290,26 @@ public:
{ {
int input = addNodeToMatch(""); int input = addNodeToMatch("");
int relu = addNodeToMatch("Relu", input); int relu = addNodeToMatch("Relu", input);
int maxValue = addNodeToMatch("Const"); maxValueId = addNodeToMatch("Const");
int clipValue = addNodeToMatch("Const"); int clipValue = addNodeToMatch("Const");
int minimum = addNodeToMatch("Minimum", relu, maxValue); int minimum = addNodeToMatch("Minimum", relu, maxValueId);
addNodeToMatch("Maximum", minimum, clipValue); addNodeToMatch("Maximum", minimum, clipValue);
setFusedNode("Relu6", input); setFusedNode("Relu6", input);
} }
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds, std::vector<int>& matchedNodesIds) CV_OVERRIDE
std::vector<int>& targetNodesIds) CV_OVERRIDE
{ {
if (!Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) if (!Subgraph::match(net, nodeId, matchedNodesIds))
return false; return false;
tensorflow::NodeDef* node = net->getNode(matchedNodesIds.front() + 1).dynamicCast<TFNodeWrapper>()->node; tensorflow::NodeDef* node = net->getNode(matchedNodesIds[maxValueId]).dynamicCast<TFNodeWrapper>()->node;
Mat maxValue = getTensorContent(node->attr().at("value").tensor()); Mat maxValue = getTensorContent(node->attr().at("value").tensor());
return maxValue.type() == CV_32FC1 && maxValue.total() == 1 && maxValue.at<float>(0) == 6; return maxValue.type() == CV_32FC1 && maxValue.total() == 1 && maxValue.at<float>(0) == 6;
} }
private:
int maxValueId;
}; };
// Keras' reshape stores output shape in separate Const nodes by one value. // Keras' reshape stores output shape in separate Const nodes by one value.
@ -328,15 +338,14 @@ public:
} }
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds, std::vector<int>& matchedNodesIds) CV_OVERRIDE
std::vector<int>& targetNodesIds) CV_OVERRIDE
{ {
Ptr<ImportNodeWrapper> node = net->getNode(nodeId); Ptr<ImportNodeWrapper> node = net->getNode(nodeId);
if (node->getNumInputs() == 0) if (node->getNumInputs() == 0)
return false; return false;
inpName = node->getInputName(0); inpName = node->getInputName(0);
return Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds); return Subgraph::match(net, nodeId, matchedNodesIds);
} }

Loading…
Cancel
Save