Merge pull request #24906 from Abdurrahheem:ash/fix_einsum_inner

Einsum Layer Inner Product Issue Solution
pull/24912/head
Alexander Smorkalov 11 months ago committed by GitHub
commit d6424233f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 4
      modules/dnn/src/layers/einsum_layer.cpp
  2. 2
      modules/dnn/test/test_onnx_importer.cpp

@ -1367,7 +1367,7 @@ Mat LayerEinsumImpl::batchwiseMatMul(
// input1 should of size MxK // input1 should of size MxK
// check if input1 needs reshape, if need reshape // check if input1 needs reshape, if need reshape
if (input1.dims > 2 || input1.size[0] != M || input1.size[1] != K) if (input1.dims > 2 || input1.size[0] != M || (input1.dims > 1 && input1.size[1] != K) || input1.dims == 1)
{ {
int shape[] = {M, K}; int shape[] = {M, K};
reshapedInput1 = input1.reshape(1, 2, shape); reshapedInput1 = input1.reshape(1, 2, shape);
@ -1375,7 +1375,7 @@ Mat LayerEinsumImpl::batchwiseMatMul(
// input2 should be of size KxN // input2 should be of size KxN
// check if input2 needs reshape, if needs reshape // check if input2 needs reshape, if needs reshape
if (input2.dims > 2 || input2.size[0] != K || input2.size[1] != N) if (input2.dims > 2 || input2.size[0] != K || (input2.dims > 1 && input2.size[1] != N) || input2.dims == 1)
{ {
int shape2[] = {K, N}; int shape2[] = {K, N};
reshapedInput2 = input2.reshape(1, 2, shape2); reshapedInput2 = input2.reshape(1, 2, shape2);

@ -1496,7 +1496,7 @@ TEST_P(Test_ONNX_layers, Einsum_5D)
} }
// https://github.com/opencv/opencv/issues/24883 // https://github.com/opencv/opencv/issues/24883
TEST_P(Test_ONNX_layers, DISABLED_Einsum_InnerProduct) TEST_P(Test_ONNX_layers, Einsum_InnerProduct)
{ {
testONNXModels("einsum_inner", npy, 0, 0, false, false, 2); testONNXModels("einsum_inner", npy, 0, 0, false, false, 2);
} }

Loading…
Cancel
Save