diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index c6e54d6a92..5aad1c135c 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -531,6 +531,38 @@ public: } }; +class MishSubgraph2 : public Subgraph +{ +public: + MishSubgraph2() + { + int input = addNodeToMatch(""); + int exp = addNodeToMatch("Exp", input); + int addVal = addNodeToMatch(""); + int add = addNodeToMatch("Add", addVal, exp); + int log = addNodeToMatch("Log", add); + int tanh = addNodeToMatch("Tanh", log); + addNodeToMatch("Mul", input, tanh); + setFusedNode("Mish", input); + } +}; + +class MishSubgraph3 : public Subgraph +{ +public: + MishSubgraph3() + { + int input = addNodeToMatch(""); + int exp = addNodeToMatch("Exp", input); + int addVal = addNodeToMatch(""); + int add = addNodeToMatch("Add", exp, addVal); + int log = addNodeToMatch("Log", add); + int tanh = addNodeToMatch("Tanh", log); + addNodeToMatch("Mul", input, tanh); + setFusedNode("Mish", input); + } +}; + class MulCastSubgraph : public Subgraph { public: @@ -735,6 +767,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net) subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 578e0442b2..39c635a095 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -1325,6 +1325,7 @@ TEST_P(Test_ONNX_layers, ResizeOpset11_Torch1_6) TEST_P(Test_ONNX_layers, Mish) { testONNXModels("mish"); + testONNXModels("mish_no_softplus"); } TEST_P(Test_ONNX_layers, CalculatePads)