From 9ed1332355a300b66427aa229a58dc4a379d931a Mon Sep 17 00:00:00 2001
From: Liubov Batanina
Date: Wed, 4 Mar 2020 11:27:10 +0300
Subject: [PATCH] Merge pull request #16722 from l-bat:reshape_opset_11

* Supported Div op for constants

* Added Mul test

---
 modules/dnn/src/onnx/onnx_importer.cpp  | 69 ++++++++++++++-----------
 modules/dnn/test/test_onnx_importer.cpp |  2 +
 2 files changed, 41 insertions(+), 30 deletions(-)

diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp
index 6f3ac0409d..3d7e33a37f 100644
--- a/modules/dnn/src/onnx/onnx_importer.cpp
+++ b/modules/dnn/src/onnx/onnx_importer.cpp
@@ -465,31 +465,6 @@ void ONNXImporter::populateNet(Net dstNet)
                 layerParams.blobs.push_back(-1.0f * blob.reshape(1, 1));
             }
         }
-        else if (layer_type == "Div")
-        {
-            if (constBlobs.find(node_proto.input(1)) == constBlobs.end())
-            {
-                layerParams.type = "Eltwise";
-                layerParams.set("operation", "div");
-            }
-            else
-            {
-                Mat blob = getBlob(node_proto, constBlobs, 1);
-                CV_Assert_N(blob.type() == CV_32F, blob.total());
-                if (blob.total() == 1)
-                {
-                    layerParams.set("scale", 1.0f / blob.at<float>(0));
-                    layerParams.type = "Power";
-                }
-                else
-                {
-                    layerParams.type = "Scale";
-                    divide(1.0, blob, blob);
-                    layerParams.blobs.push_back(blob);
-                    layerParams.set("bias_term", false);
-                }
-            }
-        }
         else if (layer_type == "Neg")
         {
             layerParams.type = "Power";
@@ -638,24 +613,58 @@ void ONNXImporter::populateNet(Net dstNet)
             layerParams.set("bias_term", false);
             layerParams.set("num_output", layerParams.blobs[0].size[0]);
         }
-        else if (layer_type == "Mul")
+        else if (layer_type == "Mul" || layer_type == "Div")
         {
             CV_Assert(node_proto.input_size() == 2);
-            if (layer_id.find(node_proto.input(1)) == layer_id.end()) {
-                Mat blob = getBlob(node_proto, constBlobs, 1);
+
+            bool isDiv = layer_type == "Div";
+            int constId = -1;
+            bool haveVariables = false;
+            for (int i = 0; i < 2; ++i)
+            {
+                if (constBlobs.find(node_proto.input(i)) != constBlobs.end())
+                    constId = i;
+                else
+                    haveVariables = true;
+            }
+            if (constId != -1 && haveVariables)
+            {
+                Mat blob = getBlob(node_proto, constBlobs, constId);
                 blob = blob.reshape(1, 1);
                 if (blob.total() == 1) {
-                    layerParams.set("scale", blob.at<float>(0));
+                    float coeff = isDiv ? 1.0 / blob.at<float>(0) : blob.at<float>(0);
+                    layerParams.set("scale", coeff);
                     layerParams.type = "Power";
                 }
                 else {
+                    if (isDiv)
+                        divide(1.0, blob, blob);
                     layerParams.blobs.push_back(blob);
                     layerParams.type = "Scale";
                 }
             }
             else {
                 layerParams.type = "Eltwise";
-                layerParams.set("operation", "prod");
+                layerParams.set("operation", isDiv ? "div" : "prod");
+            }
+
+            if (!haveVariables)
+            {
+                Mat inp0 = getBlob(node_proto, constBlobs, 0);
+                Mat inp1 = getBlob(node_proto, constBlobs, 1);
+                if (inp0.size != inp1.size)
+                    CV_Error(Error::StsNotImplemented, "Constant multiply with different shapes");
+
+                Mat out;
+                if (isDiv)
+                    divide(inp0, inp1, out);
+                else
+                    multiply(inp0, inp1, out);
+
+                out = out.reshape(1, inp0.dims, inp0.size);
+                out.dims = inp0.dims;  // to workaround dims == 1
+                constBlobs.insert(std::make_pair(layerParams.name, out));
+                continue;
             }
         }
         else if (layer_type == "Conv")
diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp
index bb7cba1180..2838a72ea7 100644
--- a/modules/dnn/test/test_onnx_importer.cpp
+++ b/modules/dnn/test/test_onnx_importer.cpp
@@ -382,6 +382,8 @@ TEST_P(Test_ONNX_layers, DynamicReshape)
         if (target == DNN_TARGET_OPENCL)
             applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_OPENCL, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
     }
     testONNXModels("dynamic_reshape");
+    testONNXModels("dynamic_reshape_opset_11");
+    testONNXModels("flatten_by_prod");
 }
 TEST_P(Test_ONNX_layers, Reshape)