|
|
@ -1367,7 +1367,7 @@ Mat LayerEinsumImpl::batchwiseMatMul( |
|
|
|
|
|
|
|
|
|
|
|
// input1 should of size MxK
|
|
|
|
// input1 should of size MxK
|
|
|
|
// check if input1 needs reshape, if need reshape
|
|
|
|
// check if input1 needs reshape, if need reshape
|
|
|
|
if (input1.dims > 2 || input1.size[0] != M || input1.size[1] != K) |
|
|
|
if (input1.dims > 2 || input1.size[0] != M || (input1.dims > 1 && input1.size[1] != K) || input1.dims == 1) |
|
|
|
{ |
|
|
|
{ |
|
|
|
int shape[] = {M, K}; |
|
|
|
int shape[] = {M, K}; |
|
|
|
reshapedInput1 = input1.reshape(1, 2, shape); |
|
|
|
reshapedInput1 = input1.reshape(1, 2, shape); |
|
|
@ -1375,7 +1375,7 @@ Mat LayerEinsumImpl::batchwiseMatMul( |
|
|
|
|
|
|
|
|
|
|
|
// input2 should be of size KxN
|
|
|
|
// input2 should be of size KxN
|
|
|
|
// check if input2 needs reshape, if needs reshape
|
|
|
|
// check if input2 needs reshape, if needs reshape
|
|
|
|
if (input2.dims > 2 || input2.size[0] != K || input2.size[1] != N) |
|
|
|
if (input2.dims > 2 || input2.size[0] != K || (input2.dims > 1 && input2.size[1] != N) || input2.dims == 1) |
|
|
|
{ |
|
|
|
{ |
|
|
|
int shape2[] = {K, N}; |
|
|
|
int shape2[] = {K, N}; |
|
|
|
reshapedInput2 = input2.reshape(1, 2, shape2); |
|
|
|
reshapedInput2 = input2.reshape(1, 2, shape2); |
|
|
|