|
|
@ -154,16 +154,10 @@ private: |
|
|
|
int axis; |
|
|
|
int axis; |
|
|
|
}; |
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
class NormalizeSubgraph1 : public Subgraph |
|
|
|
class NormalizeSubgraphBase : public Subgraph |
|
|
|
{ |
|
|
|
{ |
|
|
|
public: |
|
|
|
public: |
|
|
|
NormalizeSubgraph1() : axis(1) |
|
|
|
NormalizeSubgraphBase(int _normNodeOrder = 0) : axis(1), normNodeOrder(_normNodeOrder) {} |
|
|
|
{ |
|
|
|
|
|
|
|
input = addNodeToMatch(""); |
|
|
|
|
|
|
|
norm = addNodeToMatch("ReduceL2", input); |
|
|
|
|
|
|
|
addNodeToMatch("Div", input, norm); |
|
|
|
|
|
|
|
setFusedNode("Normalize", input); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, |
|
|
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, |
|
|
|
std::vector<int>& matchedNodesIds, |
|
|
|
std::vector<int>& matchedNodesIds, |
|
|
@ -171,7 +165,7 @@ public: |
|
|
|
{ |
|
|
|
{ |
|
|
|
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) |
|
|
|
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) |
|
|
|
{ |
|
|
|
{ |
|
|
|
Ptr<ImportNodeWrapper> norm = net->getNode(matchedNodesIds[0]); |
|
|
|
Ptr<ImportNodeWrapper> norm = net->getNode(matchedNodesIds[normNodeOrder]); |
|
|
|
opencv_onnx::NodeProto* node = norm.dynamicCast<ONNXNodeWrapper>()->node; |
|
|
|
opencv_onnx::NodeProto* node = norm.dynamicCast<ONNXNodeWrapper>()->node; |
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < node->attribute_size(); i++) |
|
|
|
for (int i = 0; i < node->attribute_size(); i++) |
|
|
@ -204,20 +198,51 @@ public: |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
protected: |
|
|
|
protected: |
|
|
|
int input, norm; |
|
|
|
int axis, normNodeOrder; |
|
|
|
int axis; |
|
|
|
|
|
|
|
}; |
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NormalizeSubgraph1 : public NormalizeSubgraphBase |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
public: |
|
|
|
|
|
|
|
NormalizeSubgraph1() |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
|
|
|
|
int norm = addNodeToMatch("ReduceL2", input); |
|
|
|
|
|
|
|
addNodeToMatch("Div", input, norm); |
|
|
|
|
|
|
|
setFusedNode("Normalize", input); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
class NormalizeSubgraph2 : public NormalizeSubgraph1 |
|
|
|
class NormalizeSubgraph2 : public NormalizeSubgraphBase |
|
|
|
{ |
|
|
|
{ |
|
|
|
public: |
|
|
|
public: |
|
|
|
NormalizeSubgraph2() : NormalizeSubgraph1() |
|
|
|
NormalizeSubgraph2() |
|
|
|
{ |
|
|
|
{ |
|
|
|
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
|
|
|
|
int norm = addNodeToMatch("ReduceL2", input); |
|
|
|
int clip = addNodeToMatch("Clip", norm); |
|
|
|
int clip = addNodeToMatch("Clip", norm); |
|
|
|
int shape = addNodeToMatch("Shape", input); |
|
|
|
int shape = addNodeToMatch("Shape", input); |
|
|
|
int expand = addNodeToMatch("Expand", clip, shape); |
|
|
|
int expand = addNodeToMatch("Expand", clip, shape); |
|
|
|
addNodeToMatch("Div", input, expand); |
|
|
|
addNodeToMatch("Div", input, expand); |
|
|
|
|
|
|
|
setFusedNode("Normalize", input); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NormalizeSubgraph3 : public NormalizeSubgraphBase |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
public: |
|
|
|
|
|
|
|
NormalizeSubgraph3() : NormalizeSubgraphBase(1) |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
|
|
|
|
int power = addNodeToMatch("Constant"); |
|
|
|
|
|
|
|
int squared = addNodeToMatch("Pow", input, power); |
|
|
|
|
|
|
|
int sum = addNodeToMatch("ReduceSum", squared); |
|
|
|
|
|
|
|
int sqrtNode = addNodeToMatch("Sqrt", sum); |
|
|
|
|
|
|
|
int eps = addNodeToMatch("Constant"); |
|
|
|
|
|
|
|
int add = addNodeToMatch("Add", sqrtNode, eps); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
addNodeToMatch("Div", input, add); |
|
|
|
|
|
|
|
setFusedNode("Normalize", input); |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
@ -368,6 +393,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net) |
|
|
|
subgraphs.push_back(makePtr<SoftMaxSubgraph>()); |
|
|
|
subgraphs.push_back(makePtr<SoftMaxSubgraph>()); |
|
|
|
subgraphs.push_back(makePtr<NormalizeSubgraph1>()); |
|
|
|
subgraphs.push_back(makePtr<NormalizeSubgraph1>()); |
|
|
|
subgraphs.push_back(makePtr<NormalizeSubgraph2>()); |
|
|
|
subgraphs.push_back(makePtr<NormalizeSubgraph2>()); |
|
|
|
|
|
|
|
subgraphs.push_back(makePtr<NormalizeSubgraph3>()); |
|
|
|
|
|
|
|
|
|
|
|
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs); |
|
|
|
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs); |
|
|
|
} |
|
|
|
} |
|
|
|