Merge pull request #24483 from dkurt:dnn_fusion_commutative_ops

Commutative rules for DNN subgraphs fusion #24483

### Pull Request Readiness Checklist

related: https://github.com/opencv/opencv/pull/24463#issuecomment-1783033931

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There are accuracy tests, performance tests and test data in the opencv_extra repository, if applicable
      The patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
Commit b7ec2ebb55 (parent 41c335e5a5), authored by Dmitry Kurtaev, committed via GitHub.
Changed files:
  1. modules/dnn/src/graph_simplifier.cpp (68 changed lines)
  2. modules/dnn/src/graph_simplifier.hpp (8 changed lines)
  3. modules/dnn/src/onnx/onnx_graph_simplifier.cpp (304 changed lines)
  4. modules/dnn/src/tensorflow/tf_graph_simplifier.cpp (27 changed lines)
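
The idea in a nutshell: operators such as `Add`, `Mul`, `Equal` and `Max` produce the same result whichever way their operands are ordered, so a fusion pattern written as, say, `Add(Erf, Constant)` should also match a model exported as `Add(Constant, Erf)`. Until now each operand order effectively needed its own subgraph pattern; with this patch the matcher tries every input ordering for ops the importer declares commutative. A minimal standalone sketch of that rule (toy types, not the OpenCV implementation):

```cpp
// Illustration only: why operand order matters when matching commutative ops.
#include <iostream>
#include <string>
#include <vector>

struct Node {
    std::string type;
    std::vector<int> inputs;  // indices of producer nodes
};

static bool isCommutative(const std::string& type) {
    return type == "Add" || type == "Mul" || type == "Equal" || type == "Max";
}

// Match a two-input pattern op(a, b) against graph[nodeId]; for commutative
// ops the expected inputs may appear in either order.
static bool matchesPattern(const std::vector<Node>& graph, int nodeId,
                           const std::string& op,
                           const std::string& inpA, const std::string& inpB) {
    const Node& n = graph[nodeId];
    if (n.type != op || n.inputs.size() != 2)
        return false;
    const std::string& t0 = graph[n.inputs[0]].type;
    const std::string& t1 = graph[n.inputs[1]].type;
    if (t0 == inpA && t1 == inpB)
        return true;
    return isCommutative(op) && t0 == inpB && t1 == inpA;
}

int main() {
    // Graph node 3 is Add(Constant, Erf); the pattern is written Add(Erf, Constant).
    std::vector<Node> graph = {
        {"Input", {}}, {"Erf", {0}}, {"Constant", {}}, {"Add", {2, 1}}
    };
    std::cout << matchesPattern(graph, 3, "Add", "Erf", "Constant") << "\n";  // 1
}
```

The real matcher below applies the same rule inside a breadth-first walk over the whole pattern rather than per node.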

@@ -77,14 +77,14 @@ int Subgraph::getInputNodeId(const Ptr<ImportGraphWrapper>& net,
}
bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds)
std::vector<int>& matchedNodesIds)
{
matchedNodesIds.clear();
targetNodesIds.clear();
std::queue<int> nodesToMatch;
std::queue<int> targetNodes;
std::vector<std::pair<int, int> > matchings;
matchings.reserve(nodes.size());
nodesToMatch.push(nodeId);
targetNodes.push(nodes.size() - 1);
while (!nodesToMatch.empty())
@@ -94,51 +94,63 @@ bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
nodesToMatch.pop();
targetNodes.pop();
if (std::find(matchedNodesIds.begin(), matchedNodesIds.end(), nodeToMatch) !=
matchedNodesIds.end())
if (std::find_if(matchings.begin(), matchings.end(), [&](const std::pair<int, int>& match){ return match.first == targetNodeId; }) !=
matchings.end())
continue;
// Empty placeholder matches with any input type
if (nodes[targetNodeId].empty()) {
matchings.push_back({targetNodeId, nodeToMatch});
continue;
}
const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch);
if (node->getType() != nodes[targetNodeId])
return false;
continue;
std::vector<int>& inputNodes = inputs[targetNodeId];
if (inputNodes.size() != node->getNumInputs())
return false;
continue;
bool isCommutative = net->isCommutativeOp(node->getType());
for (int j = 0; j < inputNodes.size(); ++j)
{
if (nodes[inputNodes[j]].empty() || node->getInputName(j).empty()) // Unknown input node type.
// Sometimes, ONNX may have input but it's empty (see Clip layer from reduceL2_subgraph2_2 testcase)
if (node->getInputName(j).empty())
continue;
nodeId = getInputNodeId(net, node, j);
const Ptr<ImportNodeWrapper> inpNode = net->getNode(nodeId);
if (inpNode->getType() != "Const" && inpNode->getType() != "Constant")
if (isCommutative)
{
for (int i = 0; i < inputNodes.size(); ++i)
{
nodesToMatch.push(nodeId);
targetNodes.push(inputNodes[i]);
}
}
else
{
nodesToMatch.push(nodeId);
targetNodes.push(inputNodes[j]);
}
else if (nodes[inputNodes[j]] != "Const" && nodes[inputNodes[j]] != "Constant")
return false;
}
matchedNodesIds.push_back(nodeToMatch);
targetNodesIds.push_back(targetNodeId);
matchings.push_back({targetNodeId, nodeToMatch});
}
if (matchings.size() != nodes.size())
return false;
const int n = matchedNodesIds.size();
std::vector<std::pair<int, int> > elements(n);
for (int i = 0; i < n; ++i)
elements[i] = std::make_pair(matchedNodesIds[i], targetNodesIds[i]);
std::sort(elements.begin(), elements.end());
for (int i = 0; i < n; ++i)
// Sort matched by pattern nodes order.
std::sort(matchings.begin(), matchings.end());
matchedNodesIds.resize(matchings.size());
for (int i = 0; i < matchings.size(); ++i)
{
matchedNodesIds[i] = elements[i].first;
targetNodesIds[i] = elements[i].second;
matchedNodesIds[i] = matchings[i].second;
}
return true;
}
void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds,
const std::vector<int>& targetNodesIds)
void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds)
{
// Extract names of input nodes.
std::vector<std::string> inputsNames(fusedNodeInputs.size());
@@ -149,9 +161,9 @@ void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int
for (int j = 0; j < matchedNodesIds.size() && inpName.empty(); ++j)
{
Ptr<ImportNodeWrapper> node = net->getNode(matchedNodesIds[j]);
std::vector<int>& inpIndices = inputs[targetNodesIds[j]];
std::vector<int>& inpIndices = inputs[j];
CV_Assert(node->getNumInputs() == inpIndices.size());
CV_Assert(inpIndices.empty() || node->getNumInputs() == inpIndices.size());
for (int k = 0; k < inpIndices.size(); ++k)
{
if (inpIndices[k] == fusedNodeInputs[i])
@@ -187,15 +199,15 @@ void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net,
const std::vector<Ptr<Subgraph> >& patterns)
{
int numNodes = net->getNumNodes();
std::vector<int> matchedNodesIds, targetNodesIds;
std::vector<int> matchedNodesIds;
std::vector<int> nodesToRemove;
for (int j = 0; j < patterns.size(); ++j)
{
for (int i = 0; i < numNodes; ++i)
{
if (patterns[j]->match(net, i, matchedNodesIds, targetNodesIds))
if (patterns[j]->match(net, i, matchedNodesIds))
{
patterns[j]->replace(net, matchedNodesIds, targetNodesIds);
patterns[j]->replace(net, matchedNodesIds);
// Remove matched nodes except the last one.
nodesToRemove.insert(nodesToRemove.end(), matchedNodesIds.begin(), matchedNodesIds.end() - 1);
}
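
A note on the rewritten `Subgraph::match` above: instead of the parallel `matchedNodesIds`/`targetNodesIds` vectors it collects `(patternNodeId, graphNodeId)` pairs in `matchings`, enqueues every graph input of a commutative node against every pattern input slot, and turns type/arity mismatches into `continue` (another queued candidate may still fit) rather than an immediate `return false`. The match only succeeds if every pattern node ends up paired; the pairs are then sorted by pattern id. A small sketch of that final bookkeeping step, assuming `matchings` is laid out as in the code above:

```cpp
// Sketch: order matched graph nodes by pattern-node id so that subclasses can
// index matchedNodesIds with the ids returned by addNodeToMatch().
#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>

std::vector<int> orderByPattern(std::vector<std::pair<int, int>> matchings) {
    std::sort(matchings.begin(), matchings.end());      // sort by pattern id
    std::vector<int> matchedNodesIds(matchings.size());
    for (size_t i = 0; i < matchings.size(); ++i)
        matchedNodesIds[i] = matchings[i].second;        // keep graph ids only
    return matchedNodesIds;
}

int main() {
    // Pattern nodes 0..2 matched graph nodes 5, 6, 7, discovered out of order.
    for (int id : orderByPattern({{2, 7}, {0, 5}, {1, 6}}))
        std::cout << id << " ";                          // prints: 5 6 7
}
```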

@@ -44,6 +44,8 @@ public:
virtual std::string getOutputName(int nodeId, int outId) const = 0;
virtual void removeNode(int idx) = 0;
virtual bool isCommutativeOp(const std::string& type) const = 0;
};
class Subgraph // Interface to match and replace subgraphs.
@@ -75,12 +77,10 @@ public:
// Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
// Const nodes are skipped during matching. Returns true if nodes are matched and can be fused.
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds);
std::vector<int>& matchedNodesIds);
// Fuse matched subgraph.
void replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds,
const std::vector<int>& targetNodesIds);
void replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds);
virtual void finalize(const Ptr<ImportGraphWrapper>& net,
const Ptr<ImportNodeWrapper>& fusedNode,
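
`isCommutativeOp` becomes a pure virtual on `ImportGraphWrapper`, so every importer has to state which ops it treats as commutative (the ONNX wrapper's list follows in the next file, the TensorFlow one near the end of the diff). A self-contained sketch with stand-in types, just to show the shape of the contract:

```cpp
// Stand-in types (not the OpenCV headers): the new pure virtual hook forces
// each graph wrapper to say which ops commute.
#include <string>

class GraphWrapperLike {  // stand-in for ImportGraphWrapper
public:
    virtual ~GraphWrapperLike() = default;
    virtual bool isCommutativeOp(const std::string& type) const = 0;
};

class ToyOnnxWrapper : public GraphWrapperLike {
public:
    // Mirrors the ONNX importer's list from this patch.
    bool isCommutativeOp(const std::string& type) const override {
        return type == "Add" || type == "Mul" || type == "Equal" || type == "Max";
    }
};
```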

@@ -125,8 +125,13 @@ public:
virtual void removeNode(int idx) CV_OVERRIDE
{
CV_Assert(idx >= numInputs + numInitializers);
net.mutable_node()->DeleteSubrange(idx - numInputs - numInitializers, 1);
if (idx >= numInputs + numInitializers)
net.mutable_node()->DeleteSubrange(idx - numInputs - numInitializers, 1);
}
virtual inline bool isCommutativeOp(const std::string& type) const CV_OVERRIDE
{
return type == "Add" || type == "Mul" || type == "Equal" || type == "Max";
}
private:
@@ -134,6 +139,25 @@ private:
opencv_onnx::GraphProto& net;
};
static Mat extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
{
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
if (initializer_id != -1)
{
return onnx_net->getMatFromInitializer(initializer_id);
}
else
{
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
int constant_id = Subgraph::getInputNodeId(net, node, input_id);
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
return getMatFromTensor(constant_proto);
}
}
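
The file-level `extractConstant` above replaces the near-identical static helpers that used to live inside the Gelu, GeluApproximation and LayerNorm subgraph classes (plus the `int64_t`-vector variant in `ExpandSubgraph`): it returns the whole `cv::Mat` and leaves the interpretation to the call site. A usage sketch with made-up values, not a real call site:

```cpp
// Sketch of the new calling convention: read a scalar or a shape vector out of
// the Mat that extractConstant() would return.
#include <opencv2/core.hpp>
#include <iostream>
#include <vector>

int main() {
    // Stand-ins for extractConstant(net, node_id, input_id) results.
    cv::Mat scalar = (cv::Mat_<float>(1, 1) << 0.5f);
    cv::Mat shape  = (cv::Mat_<int>(1, 3) << 1, 3, 224);

    float mulConst = scalar.at<float>(0);                         // e.g. Mul[B=0.5]
    std::vector<int> dims(shape.begin<int>(), shape.end<int>());  // e.g. Expand shape

    std::cout << mulConst << " " << dims.size() << "\n";
}
```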
/* Fusion for Gelu.
Graph before fusion:
@@ -151,54 +175,32 @@ public:
GeluSubGraph()
{
int input = addNodeToMatch("");
int div = addNodeToMatch("Div", input, addNodeToMatch("") /* B=sqrt(2) */ );
div = addNodeToMatch("Div", input, addNodeToMatch("") /* B=sqrt(2) */ );
int erf = addNodeToMatch("Erf", div);
int add = addNodeToMatch("Add", erf, addNodeToMatch("") /* B=1 */ );
add = addNodeToMatch("Add", erf, addNodeToMatch("") /* B=1 */ );
int mul = addNodeToMatch("Mul", input, add);
addNodeToMatch("Mul", mul, addNodeToMatch("") /* B=0.5 */) ;
mul2 = addNodeToMatch("Mul", mul, addNodeToMatch("") /* B=0.5 */) ;
setFusedNode("Gelu", input);
}
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
{
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
if (initializer_id != -1)
{
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
return *const_mat.ptr<float>();
}
else
{
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
int constant_id = getInputNodeId(net, node, input_id);
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
Mat constant_mat = getMatFromTensor(constant_proto);
return *constant_mat.ptr<float>();
}
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds) CV_OVERRIDE
std::vector<int>& matchedNodesIds) CV_OVERRIDE
{
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
if (Subgraph::match(net, nodeId, matchedNodesIds))
{
// Check Div[B=sqrt(2)]
float divisor = extractConstant(net, matchedNodesIds[0], 1);
float divisor = extractConstant(net, matchedNodesIds[div], 1).at<float>(0);
if (std::fabs(divisor - M_SQRT2) >= std::numeric_limits<float>::epsilon())
return false;
// Check Add[B=1]
float add_const = extractConstant(net, matchedNodesIds[2], 1);
float add_const = extractConstant(net, matchedNodesIds[add], 1).at<float>(0);
if (std::fabs(add_const - 1.f) >= std::numeric_limits<float>::epsilon())
return false;
// Check Mul[B=0.5]
float mul_const = extractConstant(net, matchedNodesIds[4], 1);
float mul_const = extractConstant(net, matchedNodesIds[mul2], 1).at<float>(0);
if (std::fabs(mul_const - 0.5f) >= std::numeric_limits<float>::epsilon())
return false;
@@ -206,6 +208,9 @@ public:
}
return false;
}
private:
int div, add, mul2;
};
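
Also worth noting in `GeluSubGraph` (and the classes below): the pattern ids returned by `addNodeToMatch` are kept as members (`div`, `add`, `mul2`) and used to index `matchedNodesIds`, replacing hard-coded positions such as `matchedNodesIds[0]`. This works because `match()` sorts matches into pattern order and now records placeholder/constant pattern nodes too, so index N in `matchedNodesIds` is the graph node matched to pattern node N. A toy illustration of the indexing scheme (not the OpenCV classes):

```cpp
// Toy illustration: addNodeToMatch() hands back the pattern-node index, so a
// stored member gives a stable handle into matchedNodesIds.
#include <iostream>
#include <string>
#include <vector>

struct ToyPattern {
    std::vector<std::string> nodes;  // pattern node types
    int input, div, erf;

    int addNodeToMatch(const std::string& type) {
        nodes.push_back(type);
        return static_cast<int>(nodes.size()) - 1;
    }
    ToyPattern() {
        input = addNodeToMatch("");
        div   = addNodeToMatch("Div");
        erf   = addNodeToMatch("Erf");
    }
};

int main() {
    ToyPattern p;
    std::vector<int> matchedNodesIds = {10, 11, 12};  // graph ids in pattern order
    std::cout << matchedNodesIds[p.div] << "\n";      // 11, no magic index needed
}
```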
/* Fusion for GeluApproximation.
@@ -229,61 +234,39 @@ public:
int input = addNodeToMatch("");
int mul0 = addNodeToMatch("Mul", input, input);
int mul1 = addNodeToMatch("Mul", input, mul0);
int mul2 = addNodeToMatch("Mul", addNodeToMatch("") /* A=0.044714998453855515 */, mul1);
mul2 = addNodeToMatch("Mul", addNodeToMatch("") /* A=0.044714998453855515 */, mul1);
int add0 = addNodeToMatch("Add", input, mul2);
int mul3 = addNodeToMatch("Mul", addNodeToMatch("") /* A=sqrt(2/pie) */, add0);
mul3 = addNodeToMatch("Mul", addNodeToMatch("") /* A=sqrt(2/pie) */, add0);
int tanh = addNodeToMatch("Tanh", mul3);
int add1 = addNodeToMatch("Add", addNodeToMatch("") /* A=1 */, tanh);
add1 = addNodeToMatch("Add", addNodeToMatch("") /* A=1 */, tanh);
int mul4 = addNodeToMatch("Mul", input, add1);
addNodeToMatch("Mul", addNodeToMatch("") /* A=0.5 */, mul4);
mul5 = addNodeToMatch("Mul", addNodeToMatch("") /* A=0.5 */, mul4);
setFusedNode("GeluApproximation", input);
}
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
{
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
if (initializer_id != -1)
{
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
return *const_mat.ptr<float>();
}
else
{
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
int constant_id = getInputNodeId(net, node, input_id);
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
Mat constant_mat = getMatFromTensor(constant_proto);
return *constant_mat.ptr<float>();
}
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds) CV_OVERRIDE
std::vector<int>& matchedNodesIds) CV_OVERRIDE
{
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
if (Subgraph::match(net, nodeId, matchedNodesIds))
{
// Check Mul[A=0.044714998453855515]
float coef = extractConstant(net, matchedNodesIds[2], 0);
float coef = extractConstant(net, matchedNodesIds[mul2], 0).at<float>(0);
if (coef - 0.044714998453855515 >= 1e-6)
return false;
// Check Mul[A=sqrt(2/pie)]
float sqrt_2_pie = extractConstant(net, matchedNodesIds[4], 0);
float sqrt_2_pie = extractConstant(net, matchedNodesIds[mul3], 0).at<float>(0);
if (sqrt_2_pie - 0.7978845834732056 >= 1e-6)
return false;
// Check Add[A=1]
float add_const = extractConstant(net, matchedNodesIds[6], 0);
float add_const = extractConstant(net, matchedNodesIds[add1], 0).at<float>(0);
if (add_const - 1.f >= 1e-6)
return false;
// Check Mul[A=0.5]
float mul_const = extractConstant(net, matchedNodesIds[8], 0);
float mul_const = extractConstant(net, matchedNodesIds[mul5], 0).at<float>(0);
if (mul_const - 0.5f >= 1e-6)
return false;
@@ -291,6 +274,9 @@ public:
}
return false;
}
private:
int mul2, mul3, add1, mul5;
};
/* Fusion for LayerNormalization.
@@ -313,43 +299,22 @@ public:
LayerNormSubGraph() : axis(-1), epsilon(1e-5)
{
int input = addNodeToMatch("");
int mean = addNodeToMatch("ReduceMean", input);
mean = addNodeToMatch("ReduceMean", input);
int sub = addNodeToMatch("Sub", input, mean);
int pow = addNodeToMatch("Pow", sub, addNodeToMatch(""));
int mean1 = addNodeToMatch("ReduceMean", pow);
int add = addNodeToMatch("Add", mean1, addNodeToMatch(""));
pow = addNodeToMatch("Pow", sub, addNodeToMatch(""));
mean1 = addNodeToMatch("ReduceMean", pow);
add = addNodeToMatch("Add", mean1, addNodeToMatch(""));
int sqrt = addNodeToMatch("Sqrt", add);
int div = addNodeToMatch("Div", sub, sqrt);
int mul = addNodeToMatch("Mul", div, addNodeToMatch(""));
addNodeToMatch("Add", mul, addNodeToMatch(""));
mul = addNodeToMatch("Mul", div, addNodeToMatch(""));
bias = addNodeToMatch("Add", mul, addNodeToMatch(""));
setFusedNode("LayerNormalization", input);
}
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
{
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
if (initializer_id != -1) // initializer
{
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
return *const_mat.ptr<float>();
}
else
{
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
int constant_id = getInputNodeId(net, node, input_id);
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
Mat constant_mat = getMatFromTensor(constant_proto);
return *constant_mat.ptr<float>();
}
}
static float extractAxis(const Ptr<ImportGraphWrapper>& net, int node_id)
{
Ptr<ImportNodeWrapper> mean_ptr = net->getNode(node_id);
@@ -381,25 +346,24 @@ public:
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds) CV_OVERRIDE
std::vector<int>& matchedNodesIds) CV_OVERRIDE
{
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
if (Subgraph::match(net, nodeId, matchedNodesIds))
{
float pow_exp = extractConstant(net, matchedNodesIds[2], 1);
float pow_exp = extractConstant(net, matchedNodesIds[pow], 1).at<float>(0);
if (pow_exp - 2 > 1e-5) // not pow(2)
return false;
int axis_mean1 = extractAxis(net, matchedNodesIds[0]);
int axis_mean2 = extractAxis(net, matchedNodesIds[3]);
int axis_mean1 = extractAxis(net, matchedNodesIds[mean]);
int axis_mean2 = extractAxis(net, matchedNodesIds[mean1]);
if (axis_mean1 != axis_mean2)
return false;
axis = axis_mean1;
epsilon = extractConstant(net, matchedNodesIds[4], 1);
epsilon = extractConstant(net, matchedNodesIds[add], 1).at<float>(0);
weight_name = getInputName(net, matchedNodesIds[7], 1);
bias_name = getInputName(net, matchedNodesIds[8], 1);
weight_name = getInputName(net, matchedNodesIds[mul], 1);
bias_name = getInputName(net, matchedNodesIds[bias], 1);
return true;
}
@@ -429,6 +393,7 @@ protected:
float epsilon;
std::string weight_name;
std::string bias_name;
int pow, mean, mean1, add, mul, bias;
};
class SoftMaxSubgraphBase : public Subgraph
@@ -437,10 +402,9 @@ public:
SoftMaxSubgraphBase() : axis(1), id(-1) {}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds) CV_OVERRIDE
std::vector<int>& matchedNodesIds) CV_OVERRIDE
{
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
if (Subgraph::match(net, nodeId, matchedNodesIds))
{
CV_Assert(id >= 0 && id < matchedNodesIds.size());
Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[id]);
@@ -485,7 +449,7 @@ public:
int inpExp = addNodeToMatch("Exp", input);
int sum = addNodeToMatch("ReduceSum", inpExp);
id = 1;
id = sum;
addNodeToMatch("Div", inpExp, sum);
setFusedNode("Softmax", input);
@@ -498,7 +462,7 @@ public:
int input = addNodeToMatch("");
int reducemax = addNodeToMatch("ReduceMax", input);
id = 0;
id = reducemax;
int sub = addNodeToMatch("Sub", input, reducemax);
int exp = addNodeToMatch("Exp", sub);
@@ -516,7 +480,7 @@ public:
int input = addNodeToMatch("");
int reducemax = addNodeToMatch("ReduceMax", input);
id = 0;
id = reducemax;
int sub_1 = addNodeToMatch("Sub", input, reducemax);
int exp = addNodeToMatch("Exp", sub_1);
@@ -533,18 +497,17 @@ public:
HardSwishSubgraph()
{
int input = addNodeToMatch("");
int hardSigmoid = addNodeToMatch("HardSigmoid", input);
addNodeToMatch("Mul", input, hardSigmoid);
hardSigmoidId = addNodeToMatch("HardSigmoid", input);
addNodeToMatch("Mul", input, hardSigmoidId);
setFusedNode("HardSwish", input);
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds) CV_OVERRIDE
std::vector<int>& matchedNodesIds) CV_OVERRIDE
{
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
if (Subgraph::match(net, nodeId, matchedNodesIds))
{
Ptr<ImportNodeWrapper> hardSigmoid = net->getNode(matchedNodesIds[0]);
Ptr<ImportNodeWrapper> hardSigmoid = net->getNode(matchedNodesIds[hardSigmoidId]);
opencv_onnx::NodeProto* node = hardSigmoid.dynamicCast<ONNXNodeWrapper>()->node;
uint8_t matched = 0;
@@ -561,6 +524,9 @@ public:
}
return false;
}
private:
int hardSigmoidId;
};
class CeluSubgraph : public Subgraph
@@ -569,9 +535,9 @@ public:
CeluSubgraph() : alpha(1.f)
{
int input = addNodeToMatch("");
int div = addNodeToMatch("Div", input, addNodeToMatch(""));
int elu = addNodeToMatch("Elu", div);
addNodeToMatch("Mul", addNodeToMatch(""), elu);
div = addNodeToMatch("Div", input, addNodeToMatch(""));
elu = addNodeToMatch("Elu", div);
mul = addNodeToMatch("Mul", addNodeToMatch(""), elu);
setFusedNode("Celu", input);
}
@@ -587,16 +553,15 @@ public:
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds) CV_OVERRIDE
std::vector<int>& matchedNodesIds) CV_OVERRIDE
{
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
if (Subgraph::match(net, nodeId, matchedNodesIds))
{
float alpha_div = extractAlpha(net, matchedNodesIds[0], 1);
float alpha_mul = extractAlpha(net, matchedNodesIds[2], 0);
float alpha_div = extractAlpha(net, matchedNodesIds[div], 1);
float alpha_mul = extractAlpha(net, matchedNodesIds[mul], 0);
float alpha_elu = 1.f;
Ptr<ImportNodeWrapper> elu_ptr = net->getNode(matchedNodesIds[1]);
Ptr<ImportNodeWrapper> elu_ptr = net->getNode(matchedNodesIds[elu]);
opencv_onnx::NodeProto* elu_node = elu_ptr.dynamicCast<ONNXNodeWrapper>()->node;
for (int i = 0; i < elu_node->attribute_size(); i++)
@@ -625,18 +590,18 @@ public:
protected:
float alpha;
int div, mul, elu;
};
class NormalizeSubgraphBase : public Subgraph
{
public:
NormalizeSubgraphBase(int _normNodeOrder = 0) : axis(1), normNodeOrder(_normNodeOrder) {}
NormalizeSubgraphBase(int _normNodeOrder = 1) : axis(1), normNodeOrder(_normNodeOrder) {}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds) CV_OVERRIDE
std::vector<int>& matchedNodesIds) CV_OVERRIDE
{
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
if (Subgraph::match(net, nodeId, matchedNodesIds))
{
Ptr<ImportNodeWrapper> norm = net->getNode(matchedNodesIds[normNodeOrder]);
opencv_onnx::NodeProto* node = norm.dynamicCast<ONNXNodeWrapper>()->node;
@@ -725,7 +690,7 @@ public:
class NormalizeSubgraph3 : public NormalizeSubgraphBase
{
public:
NormalizeSubgraph3() : NormalizeSubgraphBase(1)
NormalizeSubgraph3() : NormalizeSubgraphBase(3)
{
int input = addNodeToMatch("");
int power = addNodeToMatch("Constant");
@@ -743,7 +708,7 @@ public:
class NormalizeSubgraph4 : public NormalizeSubgraphBase
{
public:
NormalizeSubgraph4() : NormalizeSubgraphBase(1)
NormalizeSubgraph4() : NormalizeSubgraphBase(2)
{
int input = addNodeToMatch("");
int mul = addNodeToMatch("Mul", input, input);
@@ -760,7 +725,7 @@ public:
class NormalizeSubgraph5 : public NormalizeSubgraphBase
{
public:
NormalizeSubgraph5() : NormalizeSubgraphBase(1)
NormalizeSubgraph5() : NormalizeSubgraphBase(2)
{
int input = addNodeToMatch("");
int mul = addNodeToMatch("Mul", input, input);
@@ -781,25 +746,24 @@ public:
{
int input = addNodeToMatch("");
int index = addNodeToMatch("Constant");
int gather = addNodeToMatch("Gather", input, index);
addNodeToMatch("Cast", gather);
gather = addNodeToMatch("Gather", input, index);
cast = addNodeToMatch("Cast", gather);
setFusedNode("Gather", input, index);
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds) CV_OVERRIDE
std::vector<int>& matchedNodesIds) CV_OVERRIDE
{
bool retVal = Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds);
bool retVal = Subgraph::match(net, nodeId, matchedNodesIds);
size_t matchedNodesNum = matchedNodesIds.size();
// Now we check if merging can be made for these Gather and Cast nodes
if (!retVal || matchedNodesNum < 2)
return retVal;
else {
int nodeToMatch = matchedNodesIds[matchedNodesNum - 1];
int nodeToMatch = matchedNodesIds[cast];
const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch);
if (node->getType() == "Cast") {
int inpNodeId = matchedNodesIds[matchedNodesNum - 2];
int inpNodeId = matchedNodesIds[gather];
const Ptr<ImportNodeWrapper> inpNode = net->getNode(inpNodeId);
if (inpNode->getType() == "Gather") {
int numNodes = net->getNumNodes();
@@ -819,6 +783,9 @@ public:
}
return retVal;
}
private:
int cast, gather;
};
/* Constant folding shape for Expand.
@@ -838,12 +805,12 @@ public:
{
int input = addNodeToMatch("");
int values = addNodeToMatch("");
int init = addNodeToMatch("ConstantOfShape", values);
init = addNodeToMatch("ConstantOfShape", values);
int coeff = addNodeToMatch("Constant");
int mul = addNodeToMatch("Mul", init, coeff);
mul = addNodeToMatch("Mul", init, coeff);
int shape = addNodeToMatch("Constant");
int condition = addNodeToMatch("Equal", shape, mul);
int where = addNodeToMatch("Where", condition, init, addNodeToMatch("Constant"));
condition = addNodeToMatch("Equal", shape, mul);
where = addNodeToMatch("Where", condition, init, addNodeToMatch("Constant"));
addNodeToMatch("Expand", input, where);
setFusedNode("Expand", input, shape);
}
@@ -872,53 +839,28 @@ public:
return 0;
}
static std::vector<int64_t> extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
{
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
Mat mat_constant;
if (initializer_id != -1) // initializer
{
mat_constant = onnx_net->getMatFromInitializer(initializer_id);
}
else
{
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
int constant_id = getInputNodeId(net, node, input_id);
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
mat_constant = getMatFromTensor(constant_proto);
}
std::vector<int64_t> retvals{mat_constant.begin<int>(), mat_constant.end<int>()};
return retvals;
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds) CV_OVERRIDE {
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) {
std::vector<int>& matchedNodesIds) CV_OVERRIDE {
if (Subgraph::match(net, nodeId, matchedNodesIds)) {
int64_t value_ConstantOfShape;
if (!extractValue(net, matchedNodesIds[0], value_ConstantOfShape)) {
if (!extractValue(net, matchedNodesIds[init], value_ConstantOfShape)) {
return false;
}
std::vector<int64_t> input_ConstantOfShape = extractConstant(net, matchedNodesIds[0], 0);
std::vector<int> input_ConstantOfShape = extractConstant(net, matchedNodesIds[init], 0);
if (input_ConstantOfShape.size() != static_cast<size_t>(1)) {
return false;
}
auto B_Mul = extractConstant(net, matchedNodesIds[1], 1);
std::vector<int> B_Mul = extractConstant(net, matchedNodesIds[mul], 1);
if (B_Mul.size() != static_cast<size_t>(1)) {
return false;
}
auto A_Equal = extractConstant(net, matchedNodesIds[2], 0);
std::vector<int> A_Equal = extractConstant(net, matchedNodesIds[condition], 0);
if (A_Equal.size() != static_cast<size_t>(input_ConstantOfShape[0])) {
return false;
}
auto Y_Where = extractConstant(net, matchedNodesIds[3], 2);
std::vector<int> Y_Where = extractConstant(net, matchedNodesIds[where], 2);
if (Y_Where.size() != A_Equal.size()) {
return false;
}
@@ -969,6 +911,9 @@ public:
protected:
std::vector<int64_t> shape;
private:
int init, mul, condition, where;
};
class MishSubgraph : public Subgraph
@@ -979,7 +924,7 @@ public:
int input = addNodeToMatch("");
int softplus = addNodeToMatch("Softplus", input);
int tanh = addNodeToMatch("Tanh", softplus);
addNodeToMatch("Mul", input, tanh);
addNodeToMatch("Mul", tanh, input);
setFusedNode("Mish", input);
}
};
@@ -999,20 +944,6 @@ public:
}
};
class SoftplusSubgraph2: public Subgraph
{
public:
SoftplusSubgraph2()
{
int input = addNodeToMatch("");
int exp = addNodeToMatch("Exp", input);
int addVal = addNodeToMatch("");
int add = addNodeToMatch("Add", exp, addVal);
addNodeToMatch("Log", add);
setFusedNode("Softplus", input);
}
};
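
For context on this deletion: `SoftplusSubgraph2` presumably differed from the surviving `SoftplusSubgraph` only in the operand order of its `Add` node; with `Add` now matched commutatively, a single pattern covers both exported forms, so the duplicate class (and its registration further down) can go.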
class MulCastSubgraph : public Subgraph
{
public:
@@ -1248,7 +1179,6 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
subgraphs.push_back(makePtr<ExpandSubgraph>());
subgraphs.push_back(makePtr<SoftplusSubgraph>());
subgraphs.push_back(makePtr<SoftplusSubgraph2>());
subgraphs.push_back(makePtr<MishSubgraph>());
subgraphs.push_back(makePtr<NormalizeSubgraph4>());
subgraphs.push_back(makePtr<NormalizeSubgraph5>());

@@ -98,6 +98,14 @@ public:
net.mutable_node()->DeleteSubrange(idx, 1);
}
virtual inline bool isCommutativeOp(const std::string& type) const CV_OVERRIDE
{
return type == "Add" || type == "Sum" ||
type == "Mul" || type == "Prod" ||
type == "Max" || type == "Maximum" || type == "Minimum" ||
type == "Mean" || type == "SquaredDifference";
}
tensorflow::GraphDef& net;
};
@@ -282,24 +290,26 @@ public:
{
int input = addNodeToMatch("");
int relu = addNodeToMatch("Relu", input);
int maxValue = addNodeToMatch("Const");
maxValueId = addNodeToMatch("Const");
int clipValue = addNodeToMatch("Const");
int minimum = addNodeToMatch("Minimum", relu, maxValue);
int minimum = addNodeToMatch("Minimum", relu, maxValueId);
addNodeToMatch("Maximum", minimum, clipValue);
setFusedNode("Relu6", input);
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds) CV_OVERRIDE
std::vector<int>& matchedNodesIds) CV_OVERRIDE
{
if (!Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
if (!Subgraph::match(net, nodeId, matchedNodesIds))
return false;
tensorflow::NodeDef* node = net->getNode(matchedNodesIds.front() + 1).dynamicCast<TFNodeWrapper>()->node;
tensorflow::NodeDef* node = net->getNode(matchedNodesIds[maxValueId]).dynamicCast<TFNodeWrapper>()->node;
Mat maxValue = getTensorContent(node->attr().at("value").tensor());
return maxValue.type() == CV_32FC1 && maxValue.total() == 1 && maxValue.at<float>(0) == 6;
}
private:
int maxValueId;
};
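
Note on the Relu6 check above: the max-value `Const` is now fetched via `matchedNodesIds[maxValueId]` instead of the old `matchedNodesIds.front() + 1` guess, since constant pattern nodes are matched and recorded like any other node and can therefore be addressed directly by their pattern id.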
// Keras' reshape stores output shape in separate Const nodes by one value.
@@ -328,15 +338,14 @@ public:
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds) CV_OVERRIDE
std::vector<int>& matchedNodesIds) CV_OVERRIDE
{
Ptr<ImportNodeWrapper> node = net->getNode(nodeId);
if (node->getNumInputs() == 0)
return false;
inpName = node->getInputName(0);
return Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds);
return Subgraph::match(net, nodeId, matchedNodesIds);
}
