@@ -107,17 +107,10 @@ private:
     opencv_onnx::GraphProto& net;
 };
 
-class SoftMaxSubgraph : public Subgraph
+class SoftMaxSubgraphBase : public Subgraph
 {
 public:
-    SoftMaxSubgraph() : axis(1)
-    {
-        int input = addNodeToMatch("");
-        int inpExp = addNodeToMatch("Exp", input);
-        int sum = addNodeToMatch("ReduceSum", inpExp);
-        addNodeToMatch("Div", inpExp, sum);
-        setFusedNode("Softmax", input);
-    }
+    SoftMaxSubgraphBase() : axis(1), id(-1) {}
 
     virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
                        std::vector<int>& matchedNodesIds,
@@ -125,7 +118,8 @@ public:
     {
         if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
         {
-            Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[1]);
+            CV_Assert(id >= 0 && id < matchedNodesIds.size());
+            Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[id]);
             opencv_onnx::NodeProto* node = sum.dynamicCast<ONNXNodeWrapper>()->node;
 
             for (int i = 0; i < node->attribute_size(); i++)
@@ -153,8 +147,60 @@ public:
         attr->set_i(axis);
     }
 
-private:
+protected:
     int axis;
+    int id;
 };
 
+class SoftMaxSubgraph : public SoftMaxSubgraphBase
+{
+public:
+    SoftMaxSubgraph()
+    {
+        int input = addNodeToMatch("");
+        int inpExp = addNodeToMatch("Exp", input);
+
+        int sum = addNodeToMatch("ReduceSum", inpExp);
+        id = 1;
+
+        addNodeToMatch("Div", inpExp, sum);
+        setFusedNode("Softmax", input);
+    }
+};
+
+class SoftMaxSubgraph2 : public SoftMaxSubgraphBase {
+public:
+    SoftMaxSubgraph2() {
+        int input = addNodeToMatch("");
+
+        int reducemax = addNodeToMatch("ReduceMax", input);
+        id = 0;
+
+        int sub = addNodeToMatch("Sub", input, reducemax);
+        int exp = addNodeToMatch("Exp", sub);
+        int reducesum = addNodeToMatch("ReduceSum", exp, addNodeToMatch(""));
+        addNodeToMatch("Div", exp, reducesum);
+        setFusedNode("Softmax", input);
+    }
+};
+
+class LogSoftMaxSubgraph : public SoftMaxSubgraphBase
+{
+public:
+    LogSoftMaxSubgraph()
+    {
+        int input = addNodeToMatch("");
+
+        int reducemax = addNodeToMatch("ReduceMax", input);
+        id = 0;
+
+        int sub_1 = addNodeToMatch("Sub", input, reducemax);
+        int exp = addNodeToMatch("Exp", sub_1);
+        int reducesum = addNodeToMatch("ReduceSum", exp, addNodeToMatch(""));
+        int log = addNodeToMatch("Log", reducesum);
+        addNodeToMatch("Sub", sub_1, log);
+        setFusedNode("LogSoftmax", input);
+    }
+};
+
 class NormalizeSubgraphBase : public Subgraph
@@ -574,6 +620,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
     subgraphs.push_back(makePtr<ResizeSubgraph1>());
     subgraphs.push_back(makePtr<ResizeSubgraph2>());
     subgraphs.push_back(makePtr<SoftMaxSubgraph>());
+    subgraphs.push_back(makePtr<SoftMaxSubgraph2>());
+    subgraphs.push_back(makePtr<LogSoftMaxSubgraph>());
    subgraphs.push_back(makePtr<NormalizeSubgraph1>());
     subgraphs.push_back(makePtr<NormalizeSubgraph2>());
     subgraphs.push_back(makePtr<NormalizeSubgraph2_2>());
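
For reference, the two graph shapes matched above reduce to the same fused node: SoftMaxSubgraph matches the plain Exp -> ReduceSum -> Div decomposition, while SoftMaxSubgraph2 (and LogSoftMaxSubgraph, which additionally takes the log) matches the max-subtracted, numerically stable form, and subtracting the maximum does not change the result because the common factor exp(-max) cancels. A minimal standalone sketch of that equivalence, independent of the patched OpenCV sources (the helper names here are illustrative only):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Plain decomposition: y_i = exp(x_i) / sum_j exp(x_j)
    // (the pattern SoftMaxSubgraph recognizes).
    static std::vector<double> softmaxNaive(const std::vector<double>& x)
    {
        std::vector<double> y(x.size());
        double sum = 0.0;
        for (size_t i = 0; i < x.size(); ++i) { y[i] = std::exp(x[i]); sum += y[i]; }
        for (size_t i = 0; i < x.size(); ++i) y[i] /= sum;
        return y;
    }

    // Max-subtracted decomposition (the pattern SoftMaxSubgraph2 recognizes).
    // Identical result: exp(x_i - m) / sum_j exp(x_j - m) cancels exp(-m).
    static std::vector<double> softmaxStable(const std::vector<double>& x)
    {
        double m = *std::max_element(x.begin(), x.end());
        std::vector<double> y(x.size());
        double sum = 0.0;
        for (size_t i = 0; i < x.size(); ++i) { y[i] = std::exp(x[i] - m); sum += y[i]; }
        for (size_t i = 0; i < x.size(); ++i) y[i] /= sum;
        return y;
    }

    int main()
    {
        std::vector<double> x = {1.0, 2.0, 3.0};
        std::vector<double> a = softmaxNaive(x), b = softmaxStable(x);
        for (size_t i = 0; i < x.size(); ++i)
            std::printf("%zu: naive=%.6f stable=%.6f\n", i, a[i], b[i]);
        return 0;
    }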