Merge pull request #21162 from rogday:softmax_simplification

pull/21186/head^2
Alexander Alekhin 3 years ago
commit 35ff9af6ce
  1. 70
      modules/dnn/src/onnx/onnx_graph_simplifier.cpp

@ -107,17 +107,10 @@ private:
opencv_onnx::GraphProto& net;
};
class SoftMaxSubgraph : public Subgraph
class SoftMaxSubgraphBase : public Subgraph
{
public:
SoftMaxSubgraph() : axis(1)
{
int input = addNodeToMatch("");
int inpExp = addNodeToMatch("Exp", input);
int sum = addNodeToMatch("ReduceSum", inpExp);
addNodeToMatch("Div", inpExp, sum);
setFusedNode("Softmax", input);
}
SoftMaxSubgraphBase() : axis(1), id(-1) {}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
@ -125,7 +118,8 @@ public:
{
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
{
Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[1]);
CV_Assert(id >= 0 && id < matchedNodesIds.size());
Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[id]);
opencv_onnx::NodeProto* node = sum.dynamicCast<ONNXNodeWrapper>()->node;
for (int i = 0; i < node->attribute_size(); i++)
@ -153,8 +147,60 @@ public:
attr->set_i(axis);
}
private:
protected:
int axis;
int id;
};
class SoftMaxSubgraph : public SoftMaxSubgraphBase
{
public:
SoftMaxSubgraph()
{
int input = addNodeToMatch("");
int inpExp = addNodeToMatch("Exp", input);
int sum = addNodeToMatch("ReduceSum", inpExp);
id = 1;
addNodeToMatch("Div", inpExp, sum);
setFusedNode("Softmax", input);
}
};
class SoftMaxSubgraph2 : public SoftMaxSubgraphBase {
public:
SoftMaxSubgraph2() {
int input = addNodeToMatch("");
int reducemax = addNodeToMatch("ReduceMax", input);
id = 0;
int sub = addNodeToMatch("Sub", input, reducemax);
int exp = addNodeToMatch("Exp", sub);
int reducesum = addNodeToMatch("ReduceSum", exp, addNodeToMatch(""));
addNodeToMatch("Div", exp, reducesum);
setFusedNode("Softmax", input);
}
};
class LogSoftMaxSubgraph : public SoftMaxSubgraphBase
{
public:
LogSoftMaxSubgraph()
{
int input = addNodeToMatch("");
int reducemax = addNodeToMatch("ReduceMax", input);
id = 0;
int sub_1 = addNodeToMatch("Sub", input, reducemax);
int exp = addNodeToMatch("Exp", sub_1);
int reducesum = addNodeToMatch("ReduceSum", exp, addNodeToMatch(""));
int log = addNodeToMatch("Log", reducesum);
addNodeToMatch("Sub", sub_1, log);
setFusedNode("LogSoftmax", input);
}
};
class NormalizeSubgraphBase : public Subgraph
@ -574,6 +620,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
subgraphs.push_back(makePtr<ResizeSubgraph1>());
subgraphs.push_back(makePtr<ResizeSubgraph2>());
subgraphs.push_back(makePtr<SoftMaxSubgraph>());
subgraphs.push_back(makePtr<SoftMaxSubgraph2>());
subgraphs.push_back(makePtr<LogSoftMaxSubgraph>());
subgraphs.push_back(makePtr<NormalizeSubgraph1>());
subgraphs.push_back(makePtr<NormalizeSubgraph2>());
subgraphs.push_back(makePtr<NormalizeSubgraph2_2>());

Loading…
Cancel
Save