diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index b5bb92e92a..61ef8b7da6 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -61,27 +61,28 @@ public: ONNXGraphWrapper(opencv_onnx::GraphProto& _net) : net(_net) { numInputs = net.input_size(); + numInitializers = net.initializer_size(); } virtual Ptr<ImportNodeWrapper> getNode(int idx) const CV_OVERRIDE { opencv_onnx::NodeProto* node = 0; - if (idx >= numInputs) - node = net.mutable_node(idx - numInputs); + if (idx >= numInputs + numInitializers) + node = net.mutable_node(idx - numInputs - numInitializers); return makePtr<ONNXNodeWrapper>(node); } virtual int getNumNodes() const CV_OVERRIDE { - return numInputs + net.node_size(); + return numInputs + numInitializers + net.node_size(); } virtual int getNumOutputs(int nodeId) const CV_OVERRIDE { - if (nodeId < numInputs) + if (nodeId < numInputs + numInitializers) return 1; else - return net.node(nodeId - numInputs).output_size(); + return net.node(nodeId - numInputs - numInitializers).output_size(); } virtual std::string getOutputName(int nodeId, int outId) const CV_OVERRIDE @@ -89,18 +90,20 @@ public: CV_Assert(outId < getNumOutputs(nodeId)); if (nodeId < numInputs) return net.input(nodeId).name(); + else if (nodeId < numInputs + numInitializers) + return net.initializer(nodeId - numInputs).name(); else - return net.node(nodeId - numInputs).output(outId); + return net.node(nodeId - numInputs - numInitializers).output(outId); } virtual void removeNode(int idx) CV_OVERRIDE { - CV_Assert(idx >= numInputs); - net.mutable_node()->DeleteSubrange(idx - numInputs, 1); + CV_Assert(idx >= numInputs + numInitializers); + net.mutable_node()->DeleteSubrange(idx - numInputs - numInitializers, 1); } private: - int numInputs; + int numInputs, numInitializers; opencv_onnx::GraphProto& net; }; @@ -382,33 +385,63 @@ public: } }; -class BatchNormalizationSubgraph : public Subgraph +class BatchNormalizationSubgraphBase : public Subgraph { public: - BatchNormalizationSubgraph() + BatchNormalizationSubgraphBase() { - int input = addNodeToMatch(""); - int data1 = addNodeToMatch("Constant"); - int data2 = addNodeToMatch("Constant"); - int data3 = addNodeToMatch("Constant"); - int data4 = addNodeToMatch("Constant"); - int shape1 = addNodeToMatch("Constant"); - int reshape1 = addNodeToMatch("Reshape", data1, shape1); - int shape2 = addNodeToMatch("Constant"); - int reshape2 = addNodeToMatch("Reshape", data2, shape2); + input = addNodeToMatch(""); + var = addNodeToMatch(""); + mean = addNodeToMatch(""); + weight = addNodeToMatch(""); + bias = addNodeToMatch(""); + A = addNodeToMatch(""); + shape1 = addNodeToMatch(""); + shape2 = addNodeToMatch(""); + } +protected: + int input, var, mean, weight, bias, A, shape1, shape2; +}; + +class BatchNormalizationSubgraph1 : public BatchNormalizationSubgraphBase +{ +public: + BatchNormalizationSubgraph1() + { + int reshape1 = addNodeToMatch("Reshape", weight, shape1); + int reshape2 = addNodeToMatch("Reshape", bias, shape2); int shape3 = addNodeToMatch("Constant"); - int reshape3 = addNodeToMatch("Reshape", data3, shape3); + int reshape3 = addNodeToMatch("Reshape", var, shape3); int shape4 = addNodeToMatch("Constant"); - int reshape4 = addNodeToMatch("Reshape", data4, shape4); + int reshape4 = addNodeToMatch("Reshape", mean, shape4); int sqrtNode = addNodeToMatch("Sqrt", reshape3); - int A = addNodeToMatch("Constant"); int divNode = addNodeToMatch("Div", A, sqrtNode); int mul1 = addNodeToMatch("Mul", reshape1, divNode); int mul2 = addNodeToMatch("Mul", reshape4, mul1); int sub = addNodeToMatch("Sub", reshape2, mul2); int mul3 = addNodeToMatch("Mul", input, mul1); addNodeToMatch("Add", mul3, sub); - setFusedNode("BatchNormalization", input, data1, data2, data4 ,data3); + setFusedNode("BatchNormalization", input, weight, bias, mean, var); + } +}; + +class BatchNormalizationSubgraph2 : public BatchNormalizationSubgraphBase +{ +public: + BatchNormalizationSubgraph2() + { + int sqrtNode = addNodeToMatch("Sqrt", var); + int divNode = addNodeToMatch("Div", A, sqrtNode); + int mul1 = addNodeToMatch("Mul", weight, divNode); + int reshape2 = addNodeToMatch("Reshape", mul1, shape2); + + int mulMean = addNodeToMatch("Mul", mean, mul1); + int sub = addNodeToMatch("Sub", bias, mulMean); + int reshape1 = addNodeToMatch("Reshape", sub, shape1); + + int mulInput = addNodeToMatch("Mul", input, reshape2); + addNodeToMatch("Add", mulInput, reshape1); + setFusedNode("BatchNormalization", input, weight, bias, mean, var); } }; @@ -424,7 +457,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net) subgraphs.push_back(makePtr<NormalizeSubgraph1>()); subgraphs.push_back(makePtr<NormalizeSubgraph2>()); subgraphs.push_back(makePtr<NormalizeSubgraph3>()); - subgraphs.push_back(makePtr<BatchNormalizationSubgraph>()); + subgraphs.push_back(makePtr<BatchNormalizationSubgraph1>()); + subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>()); simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs); } diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 92fc3845c3..2b0d846721 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -309,30 +309,11 @@ static void addConstant(const std::string& name, outShapes.insert(std::make_pair(name, shape(blob))); } -void addConstantNodesForInitializers(opencv_onnx::GraphProto& graph_proto) -{ - int num_initializers = graph_proto.initializer_size(); - for (int id = 0; id < num_initializers; id++) - { - opencv_onnx::TensorProto initializer = graph_proto.initializer(id); - opencv_onnx::NodeProto* constant_node = graph_proto.add_node(); - constant_node->set_op_type("Constant"); - constant_node->set_name(initializer.name()); - constant_node->add_output(initializer.name()); - opencv_onnx::AttributeProto* value = constant_node->add_attribute(); - opencv_onnx::TensorProto* tensor = initializer.New(); - tensor->CopyFrom(initializer); - releaseONNXTensor(initializer); - value->set_allocated_t(tensor); - } -} - void ONNXImporter::populateNet(Net dstNet) { CV_Assert(model_proto.has_graph()); opencv_onnx::GraphProto graph_proto = model_proto.graph(); - addConstantNodesForInitializers(graph_proto); simplifySubgraphs(graph_proto); std::map<std::string, Mat> constBlobs = getGraphTensors(graph_proto); diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index e0a4d4f665..cfffc9629a 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -306,6 +306,13 @@ TEST_P(Test_ONNX_layers, BatchNormalizationUnfused) testONNXModels("frozenBatchNorm2d"); } +TEST_P(Test_ONNX_layers, BatchNormalizationSubgraph) +{ + if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH) + applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH); + testONNXModels("batch_norm_subgraph"); +} + TEST_P(Test_ONNX_layers, Transpose) { if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)