diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index a1b60c52e8..b81ccf106c 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -249,6 +249,40 @@ public: } }; +class NormalizeSubgraph4 : public NormalizeSubgraphBase +{ +public: + NormalizeSubgraph4() : NormalizeSubgraphBase(1) + { + int input = addNodeToMatch(""); + int mul = addNodeToMatch("Mul", input, input); + int sum = addNodeToMatch("ReduceSum", mul); + int eps = addNodeToMatch(""); + int max = addNodeToMatch("Max", sum, eps); + int sqrt = addNodeToMatch("Sqrt", max); + int reciprocal = addNodeToMatch("Reciprocal", sqrt); + addNodeToMatch("Mul", input, reciprocal); + setFusedNode("Normalize", input); + } +}; + +class NormalizeSubgraph5 : public NormalizeSubgraphBase +{ +public: + NormalizeSubgraph5() : NormalizeSubgraphBase(1) + { + int input = addNodeToMatch(""); + int mul = addNodeToMatch("Mul", input, input); + int sum = addNodeToMatch("ReduceSum", mul); + int clip = addNodeToMatch("Clip", sum); + int sqrt = addNodeToMatch("Sqrt", clip); + int one = addNodeToMatch("Constant"); + int div = addNodeToMatch("Div", one, sqrt); + addNodeToMatch("Mul", input, div); + setFusedNode("Normalize", input); + } +}; + class GatherCastSubgraph : public Subgraph { public: @@ -526,6 +560,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net) subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); simplifySubgraphs(Ptr(new ONNXGraphWrapper(net)), subgraphs); } diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index afaa4726cc..9a1f28cdea 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -403,6 +403,11 @@ TEST_P(Test_ONNX_layers, BatchNormalizationSubgraph) testONNXModels("batch_norm_subgraph"); } +TEST_P(Test_ONNX_layers, NormalizeFusionSubgraph) +{ + testONNXModels("normalize_fusion"); +} + TEST_P(Test_ONNX_layers, Transpose) { if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)