From 68eb54dc13e88f86e72157be98efe208024f960c Mon Sep 17 00:00:00 2001 From: Liubov Batanina Date: Mon, 1 Feb 2021 12:38:33 +0300 Subject: [PATCH] Added ONNX NormalizeL2 subgraph --- .../dnn/src/onnx/onnx_graph_simplifier.cpp | 36 +++++++++++++++++++ modules/dnn/test/test_onnx_importer.cpp | 5 +++ 2 files changed, 41 insertions(+) diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index a1b60c52e8..b81ccf106c 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -249,6 +249,40 @@ public: } }; +class NormalizeSubgraph4 : public NormalizeSubgraphBase +{ +public: + NormalizeSubgraph4() : NormalizeSubgraphBase(1) + { + int input = addNodeToMatch(""); + int mul = addNodeToMatch("Mul", input, input); + int sum = addNodeToMatch("ReduceSum", mul); + int eps = addNodeToMatch(""); + int max = addNodeToMatch("Max", sum, eps); + int sqrt = addNodeToMatch("Sqrt", max); + int reciprocal = addNodeToMatch("Reciprocal", sqrt); + addNodeToMatch("Mul", input, reciprocal); + setFusedNode("Normalize", input); + } +}; + +class NormalizeSubgraph5 : public NormalizeSubgraphBase +{ +public: + NormalizeSubgraph5() : NormalizeSubgraphBase(1) + { + int input = addNodeToMatch(""); + int mul = addNodeToMatch("Mul", input, input); + int sum = addNodeToMatch("ReduceSum", mul); + int clip = addNodeToMatch("Clip", sum); + int sqrt = addNodeToMatch("Sqrt", clip); + int one = addNodeToMatch("Constant"); + int div = addNodeToMatch("Div", one, sqrt); + addNodeToMatch("Mul", input, div); + setFusedNode("Normalize", input); + } +}; + class GatherCastSubgraph : public Subgraph { public: @@ -526,6 +560,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net) subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); simplifySubgraphs(Ptr(new ONNXGraphWrapper(net)), subgraphs); } diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 40d9803b5c..488f809b75 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -403,6 +403,11 @@ TEST_P(Test_ONNX_layers, BatchNormalizationSubgraph) testONNXModels("batch_norm_subgraph"); } +TEST_P(Test_ONNX_layers, NormalizeFusionSubgraph) +{ + testONNXModels("normalize_fusion"); +} + TEST_P(Test_ONNX_layers, Transpose) { if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)