let MatMul can work when both two inputs are const

pull/22873/head
zoom 2 years ago
parent bc6544c0bc
commit 5044af69d1
  1. 29
      modules/dnn/src/onnx/onnx_importer.cpp
  2. 2
      modules/dnn/test/test_onnx_importer.cpp

@ -2037,9 +2037,25 @@ void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::Node
CV_Assert(node_proto.input_size() == 2);
layerParams.type = "InnerProduct";
layerParams.set("bias_term", false);
CV_Assert(constBlobs.find(node_proto.input(0)) == constBlobs.end());
int firstInpDims = outShapes[node_proto.input(0)].size();
int secondInpDims;
int firstInpDims, secondInpDims;
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
{
Mat blob = getBlob(node_proto, 0);
firstInpDims = blob.dims;
LayerParams constParams;
constParams.name = layerParams.name + "/const_0";
constParams.type = "Const";
constParams.blobs.push_back(blob);
opencv_onnx::NodeProto tmpProto;
tmpProto.add_output(constParams.name);
addLayer(constParams, tmpProto);
node_proto.set_input(0, constParams.name);
}
else
firstInpDims = outShapes[node_proto.input(0)].size();
if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
{
@ -2053,7 +2069,7 @@ void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::Node
else
{
LayerParams constParams;
constParams.name = layerParams.name + "/const";
constParams.name = layerParams.name + "/const_1";
constParams.type = "Const";
constParams.blobs.push_back(blob);
@ -2063,9 +2079,10 @@ void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::Node
node_proto.set_input(1, constParams.name);
}
} else {
secondInpDims = outShapes[node_proto.input(1)].size();
}
else
secondInpDims = outShapes[node_proto.input(1)].size();
layerParams.set("axis", firstInpDims - secondInpDims + 1);
addLayer(layerParams, node_proto);
}

@ -961,6 +961,8 @@ TEST_P(Test_ONNX_layers, MatMul_init)
testONNXModels("matmul_2d_init");
testONNXModels("matmul_3d_init");
testONNXModels("matmul_4d_init");
testONNXModels("matmul_init_2");
}
TEST_P(Test_ONNX_layers, MatMulAdd)

Loading…
Cancel
Save