Merge pull request #22775 from WanliZhong:issue22713

pull/22855/head
Alexander Alekhin 2 years ago
commit a0a8d2160d
  1. 23
      modules/dnn/src/onnx/onnx_importer.cpp
  2. 7
      modules/dnn/test/test_onnx_importer.cpp

@ -2031,8 +2031,9 @@ void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodePr
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
CV_Assert(node_proto.input_size() == 2);
layerParams.type = "InnerProduct";
layerParams.set("bias_term", false);
@ -2044,8 +2045,24 @@ void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::Node
{
Mat blob = getBlob(node_proto, 1);
secondInpDims = blob.dims;
layerParams.blobs.push_back(blob.t());
layerParams.set("num_output", layerParams.blobs[0].size[0]);
if (secondInpDims == 2)
{
layerParams.blobs.push_back(blob.t());
layerParams.set("num_output", layerParams.blobs[0].size[0]);
}
else
{
LayerParams constParams;
constParams.name = layerParams.name + "/const";
constParams.type = "Const";
constParams.blobs.push_back(blob);
opencv_onnx::NodeProto tmpProto;
tmpProto.add_output(constParams.name);
addLayer(constParams, tmpProto);
node_proto.set_input(1, constParams.name);
}
} else {
secondInpDims = outShapes[node_proto.input(1)].size();
}

@ -956,6 +956,13 @@ TEST_P(Test_ONNX_layers, MatMul)
testONNXModels("matmul_4d");
}
TEST_P(Test_ONNX_layers, MatMul_init)
{
testONNXModels("matmul_2d_init");
testONNXModels("matmul_3d_init");
testONNXModels("matmul_4d_init");
}
TEST_P(Test_ONNX_layers, MatMulAdd)
{
#if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_EQ(2022010000)

Loading…
Cancel
Save