diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index e4cbe02840..7a0532fcf4 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -1759,15 +1759,15 @@ void ONNXImporter::parseBatchNormalization(LayerParams& layerParams, const openc addLayer(layerParams, node_proto); } +// A * B + C = Y, we require that the dimension of A is [m, k], and the dimension of B is [n, k]. +// And the dim of output Y is [m, n] void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) { CV_Assert(node_proto.input_size() >= 2); layerParams.type = "InnerProduct"; Mat weights = getBlob(node_proto, 1); - int ind_num_out = 0; - if (layerParams.has("transB") && !layerParams.get("transB")) { + if (!layerParams.get("transB", 0)) { transpose(weights, weights); - ind_num_out = 1; } layerParams.blobs.push_back(weights); @@ -1789,7 +1789,7 @@ void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodePr addLayer(constParams, proto); } - layerParams.set("num_output", layerParams.blobs[0].size[ind_num_out]); + layerParams.set("num_output", layerParams.blobs[0].size[0]); layerParams.set("bias_term", node_proto.input_size() == 3); addLayer(layerParams, node_proto); } diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 60473ede58..56203cba56 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -1389,6 +1389,12 @@ TEST_P(Test_ONNX_layers, DivConst) testONNXModels("div_const"); } +TEST_P(Test_ONNX_layers, Gemm) +{ + testONNXModels("gemm_no_transB"); + testONNXModels("gemm_transB_0"); +} + TEST_P(Test_ONNX_layers, OutputRegistration) { testONNXModels("output_registration", npy, 0, 0, false, true, 2);