Merge pull request #19330 from l-bat:lb/onnx_mish_subgraph

pull/19338/head^2
Alexander Alekhin 4 years ago
commit a122a53e72
  1. 14
      modules/dnn/src/onnx/onnx_graph_simplifier.cpp
  2. 5
      modules/dnn/test/test_onnx_importer.cpp

@ -314,6 +314,19 @@ public:
} }
}; };
class MishSubgraph : public Subgraph
{
public:
MishSubgraph()
{
int input = addNodeToMatch("");
int softplus = addNodeToMatch("Softplus", input);
int tanh = addNodeToMatch("Tanh", softplus);
addNodeToMatch("Mul", input, tanh);
setFusedNode("Mish", input);
}
};
class MulCastSubgraph : public Subgraph class MulCastSubgraph : public Subgraph
{ {
public: public:
@ -512,6 +525,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
subgraphs.push_back(makePtr<BatchNormalizationSubgraph1>()); subgraphs.push_back(makePtr<BatchNormalizationSubgraph1>());
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>()); subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
subgraphs.push_back(makePtr<ExpandSubgraph>()); subgraphs.push_back(makePtr<ExpandSubgraph>());
subgraphs.push_back(makePtr<MishSubgraph>());
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs); simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
} }

@ -660,6 +660,11 @@ TEST_P(Test_ONNX_layers, ResizeOpset11_Torch1_6)
testONNXModels("resize_opset11_torch1.6"); testONNXModels("resize_opset11_torch1.6");
} }
TEST_P(Test_ONNX_layers, Mish)
{
testONNXModels("mish");
}
TEST_P(Test_ONNX_layers, Conv1d) TEST_P(Test_ONNX_layers, Conv1d)
{ {
testONNXModels("conv1d"); testONNXModels("conv1d");

Loading…
Cancel
Save