From e0ac0cfbe2426a478fad9a2c86dd1cb5209b19c3 Mon Sep 17 00:00:00 2001 From: "ashishiva3@gmail.com" Date: Thu, 19 Mar 2020 22:22:36 +0530 Subject: [PATCH] add fused batchNorm Upsample --- .../dnn/src/onnx/onnx_graph_simplifier.cpp | 31 +++++++++++++++++++ modules/dnn/src/onnx/onnx_importer.cpp | 19 ++++++++++++ modules/dnn/test/test_onnx_importer.cpp | 19 ++++++++++++ 3 files changed, 69 insertions(+) diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index bf992feb2c..b5bb92e92a 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -382,6 +382,36 @@ public: } }; +class BatchNormalizationSubgraph : public Subgraph +{ +public: + BatchNormalizationSubgraph() + { + int input = addNodeToMatch(""); + int data1 = addNodeToMatch("Constant"); + int data2 = addNodeToMatch("Constant"); + int data3 = addNodeToMatch("Constant"); + int data4 = addNodeToMatch("Constant"); + int shape1 = addNodeToMatch("Constant"); + int reshape1 = addNodeToMatch("Reshape", data1, shape1); + int shape2 = addNodeToMatch("Constant"); + int reshape2 = addNodeToMatch("Reshape", data2, shape2); + int shape3 = addNodeToMatch("Constant"); + int reshape3 = addNodeToMatch("Reshape", data3, shape3); + int shape4 = addNodeToMatch("Constant"); + int reshape4 = addNodeToMatch("Reshape", data4, shape4); + int sqrtNode = addNodeToMatch("Sqrt", reshape3); + int A = addNodeToMatch("Constant"); + int divNode = addNodeToMatch("Div", A, sqrtNode); + int mul1 = addNodeToMatch("Mul", reshape1, divNode); + int mul2 = addNodeToMatch("Mul", reshape4, mul1); + int sub = addNodeToMatch("Sub", reshape2, mul2); + int mul3 = addNodeToMatch("Mul", input, mul1); + addNodeToMatch("Add", mul3, sub); + setFusedNode("BatchNormalization", input, data1, data2, data4 ,data3); + } +}; + void simplifySubgraphs(opencv_onnx::GraphProto& net) { std::vector > subgraphs; @@ -394,6 +424,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net) subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); simplifySubgraphs(Ptr(new ONNXGraphWrapper(net)), subgraphs); } diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 2b0d846721..92fc3845c3 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -309,11 +309,30 @@ static void addConstant(const std::string& name, outShapes.insert(std::make_pair(name, shape(blob))); } +void addConstantNodesForInitializers(opencv_onnx::GraphProto& graph_proto) +{ + int num_initializers = graph_proto.initializer_size(); + for (int id = 0; id < num_initializers; id++) + { + opencv_onnx::TensorProto initializer = graph_proto.initializer(id); + opencv_onnx::NodeProto* constant_node = graph_proto.add_node(); + constant_node->set_op_type("Constant"); + constant_node->set_name(initializer.name()); + constant_node->add_output(initializer.name()); + opencv_onnx::AttributeProto* value = constant_node->add_attribute(); + opencv_onnx::TensorProto* tensor = initializer.New(); + tensor->CopyFrom(initializer); + releaseONNXTensor(initializer); + value->set_allocated_t(tensor); + } +} + void ONNXImporter::populateNet(Net dstNet) { CV_Assert(model_proto.has_graph()); opencv_onnx::GraphProto graph_proto = model_proto.graph(); + addConstantNodesForInitializers(graph_proto); simplifySubgraphs(graph_proto); std::map constBlobs = getGraphTensors(graph_proto); diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 98b529b3f9..fd69f91a92 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -290,6 +290,15 @@ TEST_P(Test_ONNX_layers, BatchNormalization3D) testONNXModels("batch_norm_3d"); } +TEST_P(Test_ONNX_layers, BatchNormalizationUnfused) +{ + if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019) + applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER); + if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH) + applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH); + testONNXModels("frozenBatchNorm2d"); +} + TEST_P(Test_ONNX_layers, Transpose) { if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019) @@ -372,6 +381,16 @@ TEST_P(Test_ONNX_layers, ResizeUnfused) testONNXModels("resize_bilinear_unfused_opset11_torch1.4"); } +TEST_P(Test_ONNX_layers, ResizeUnfusedTwoInputs) +{ + if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019) + applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER); + if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH) + applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH); + testONNXModels("upsample_unfused_two_inputs_opset9_torch1.4", npy, 0, 0, false, true, 2); + testONNXModels("upsample_unfused_two_inputs_opset11_torch1.4", npy, 0, 0, false, true, 2); +} + TEST_P(Test_ONNX_layers, MultyInputs) { testONNXModels("multy_inputs", npy, 0, 0, false, true, 2);