add another Mish graph simplifier.

pull/22311/head
Zihao Mu 2 years ago
parent fc3e393516
commit 3c5377ca1b
  1. 34
      modules/dnn/src/onnx/onnx_graph_simplifier.cpp
  2. 1
      modules/dnn/test/test_onnx_importer.cpp

@ -531,6 +531,38 @@ public:
} }
}; };
class MishSubgraph2 : public Subgraph
{
public:
MishSubgraph2()
{
int input = addNodeToMatch("");
int exp = addNodeToMatch("Exp", input);
int addVal = addNodeToMatch("");
int add = addNodeToMatch("Add", addVal, exp);
int log = addNodeToMatch("Log", add);
int tanh = addNodeToMatch("Tanh", log);
addNodeToMatch("Mul", input, tanh);
setFusedNode("Mish", input);
}
};
class MishSubgraph3 : public Subgraph
{
public:
MishSubgraph3()
{
int input = addNodeToMatch("");
int exp = addNodeToMatch("Exp", input);
int addVal = addNodeToMatch("");
int add = addNodeToMatch("Add", exp, addVal);
int log = addNodeToMatch("Log", add);
int tanh = addNodeToMatch("Tanh", log);
addNodeToMatch("Mul", input, tanh);
setFusedNode("Mish", input);
}
};
class MulCastSubgraph : public Subgraph class MulCastSubgraph : public Subgraph
{ {
public: public:
@ -735,6 +767,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>()); subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
subgraphs.push_back(makePtr<ExpandSubgraph>()); subgraphs.push_back(makePtr<ExpandSubgraph>());
subgraphs.push_back(makePtr<MishSubgraph>()); subgraphs.push_back(makePtr<MishSubgraph>());
subgraphs.push_back(makePtr<MishSubgraph2>());
subgraphs.push_back(makePtr<MishSubgraph3>());
subgraphs.push_back(makePtr<NormalizeSubgraph4>()); subgraphs.push_back(makePtr<NormalizeSubgraph4>());
subgraphs.push_back(makePtr<NormalizeSubgraph5>()); subgraphs.push_back(makePtr<NormalizeSubgraph5>());

@ -1325,6 +1325,7 @@ TEST_P(Test_ONNX_layers, ResizeOpset11_Torch1_6)
TEST_P(Test_ONNX_layers, Mish) TEST_P(Test_ONNX_layers, Mish)
{ {
testONNXModels("mish"); testONNXModels("mish");
testONNXModels("mish_no_softplus");
} }
TEST_P(Test_ONNX_layers, CalculatePads) TEST_P(Test_ONNX_layers, CalculatePads)

Loading…
Cancel
Save