diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index cde1a78ffe..a1b60c52e8 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -314,6 +314,19 @@ public: } }; +class MishSubgraph : public Subgraph +{ +public: + MishSubgraph() + { + int input = addNodeToMatch(""); + int softplus = addNodeToMatch("Softplus", input); + int tanh = addNodeToMatch("Tanh", softplus); + addNodeToMatch("Mul", input, tanh); + setFusedNode("Mish", input); + } +}; + class MulCastSubgraph : public Subgraph { public: @@ -512,6 +525,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net) subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); simplifySubgraphs(Ptr(new ONNXGraphWrapper(net)), subgraphs); } diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index f38ca6700f..40d9803b5c 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -660,6 +660,11 @@ TEST_P(Test_ONNX_layers, ResizeOpset11_Torch1_6) testONNXModels("resize_opset11_torch1.6"); } +TEST_P(Test_ONNX_layers, Mish) +{ + testONNXModels("mish"); +} + TEST_P(Test_ONNX_layers, Conv1d) { testONNXModels("conv1d");