diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index f6dc285fad..e65c7ac3e9 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -641,6 +641,17 @@ void ONNXImporter::populateNet(Net dstNet) { layerParams.type = "Scale"; layerParams.set("bias_term", true); + int axis = 1; + for (int i = 0; i < graph_proto.initializer_size(); i++) + { + opencv_onnx::TensorProto tensor_proto = graph_proto.initializer(i); + if (tensor_proto.name() == node_proto.input(const_blob_id)) + { + axis = inpShape.size() - tensor_proto.dims_size(); + break; + } + } + layerParams.set("axis", axis); blob = blob.reshape(1, 1); layerParams.blobs.push_back((isSub ? -1 : 1) * blob); } @@ -911,13 +922,20 @@ void ONNXImporter::populateNet(Net dstNet) CV_Assert(node_proto.input_size() == 2); layerParams.type = "InnerProduct"; layerParams.set("bias_term", false); + CV_Assert(constBlobs.find(node_proto.input(0)) == constBlobs.end()); + int firstInpDims = outShapes[node_proto.input(0)].size(); + int secondInpDims; if (constBlobs.find(node_proto.input(1)) != constBlobs.end()) { Mat blob = getBlob(node_proto, constBlobs, 1); + secondInpDims = blob.dims; layerParams.blobs.push_back(blob.t()); layerParams.set("num_output", layerParams.blobs[0].size[0]); + } else { + secondInpDims = outShapes[node_proto.input(1)].size(); } + layerParams.set("axis", firstInpDims - secondInpDims + 1); } else if (layer_type == "Mul" || layer_type == "Div") { diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 6a9e68dbc5..a317be71fb 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -404,6 +404,15 @@ TEST_P(Test_ONNX_layers, MatMul) testONNXModels("matmul_4d"); } +TEST_P(Test_ONNX_layers, MatMulAdd) +{ + if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019) + applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER); + if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16) + applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16); + testONNXModels("matmul_add"); +} + TEST_P(Test_ONNX_layers, Expand) { testONNXModels("expand_batch");