@ -1299,7 +1299,6 @@ Mat LayerEinsumImpl::batchwiseMatMul(
const Mat & input2 ,
const Mat & input2 ,
const MatShape & input2ShapeOverride )
const MatShape & input2ShapeOverride )
{
{
// Sanity checks before the actual MatMul
// Sanity checks before the actual MatMul
CV_CheckType ( input1 . type ( ) , input2 . type ( ) , " Data types of the inputs must match for MatMul " ) ;
CV_CheckType ( input1 . type ( ) , input2 . type ( ) , " Data types of the inputs must match for MatMul " ) ;
CV_CheckEQ ( input1ShapeOverride . size ( ) , ( size_t ) 3 , " Only 1 batch dimension is allowed for MatMul " ) ;
CV_CheckEQ ( input1ShapeOverride . size ( ) , ( size_t ) 3 , " Only 1 batch dimension is allowed for MatMul " ) ;
@ -1312,61 +1311,22 @@ Mat LayerEinsumImpl::batchwiseMatMul(
int K = input1ShapeOverride [ 2 ] ;
int K = input1ShapeOverride [ 2 ] ;
int N = input2ShapeOverride [ 2 ] ;
int N = input2ShapeOverride [ 2 ] ;
std : : vector < Mat > output ;
Mat reshapedInput1 = input1 ;
Mat reshapedInput2 = input2 ;
Mat output ;
if ( batches > 1 )
if ( batches > 1 )
{
{
Mat reshapedInput1 = input1 ;
// create tmpout with type like input1
Mat reshapedInput2 = input2 ;
output = Mat ( { batches , M , N } , input1 . type ( ) ) ;
// input1 should of size MxK
// check if input1 needs reshape, if need reshape
if ( input1 . size [ 0 ] ! = M | | input1 . size [ 1 ] ! = K )
{
int shape [ ] = { batches , M , K } ;
reshapedInput1 = input1 . reshape ( 1 , 3 , shape ) ;
}
// input2 should be of size KxN
// check if input2 needs reshape, if needs reshape
if ( input2 . size [ 0 ] ! = K | | input2 . size [ 1 ] ! = N )
{
int shape [ ] = { batches , K , N } ;
reshapedInput2 = input2 . reshape ( 1 , 3 , shape ) ;
}
for ( size_t i = 0 ; i < batches ; i + + )
{
std : : vector < Range > ranges1 = { cv : : Range ( i , i + 1 ) } ;
for ( int j = 1 ; j < reshapedInput1 . dims ; j + + )
ranges1 . emplace_back ( cv : : Range : : all ( ) ) ;
Mat part1 = reshapedInput1 ( ranges1 ) ;
reshapedInput2 = reshapedInput2 . reshape ( 1 , input2ShapeOverride ) ;
int shape [ ] = { M , K } ;
reshapedInput1 = reshapedInput1 . reshape ( 1 , input1ShapeOverride ) ;
part1 = part1 . reshape ( 1 , sizeof ( shape ) / sizeof ( shape [ 0 ] ) , shape ) ;
std : : vector < Range > ranges2 = { cv : : Range ( i , i + 1 ) } ;
for ( int j = 1 ; j < reshapedInput2 . dims ; j + + )
ranges2 . emplace_back ( cv : : Range : : all ( ) ) ;
Mat part2 = reshapedInput2 ( ranges2 ) ;
int shape2 [ ] = { K , N } ;
part2 = part2 . reshape ( 1 , sizeof ( shape2 ) / sizeof ( shape2 [ 0 ] ) , shape2 ) ;
Mat tmp_output ( M , N , part1 . type ( ) ) ;
fastGemm ( false , false , 1.0 , part1 , part2 , 0.0 , tmp_output , opt ) ;
int newShape [ ] = { 1 , M , N } ;
tmp_output = tmp_output . reshape ( 1 , sizeof ( newShape ) / sizeof ( newShape [ 0 ] ) , newShape ) ;
output . emplace_back ( tmp_output ) ;
}
fastGemmBatch ( false , false , 1.0 , reshapedInput1 , reshapedInput2 , 0.0 , output , opt ) ;
} else {
} else {
Mat reshapedInput1 = input1 ;
Mat reshapedInput2 = input2 ;
// input1 should of size MxK
// input1 should of size MxK
// check if input1 needs reshape, if need reshape
if ( input1 . dims > 2 | | input1 . size [ 0 ] ! = M | | input1 . size [ 1 ] ! = K )
if ( input1 . dims > 2 | | input1 . size [ 0 ] ! = M | | input1 . size [ 1 ] ! = K )
{
{
int shape [ ] = { M , K } ;
int shape [ ] = { M , K } ;
@ -1374,30 +1334,18 @@ Mat LayerEinsumImpl::batchwiseMatMul(
}
}
// input2 should be of size KxN
// input2 should be of size KxN
// check if input2 needs reshape, if needs reshape
if ( input2 . dims > 2 | | input2 . size [ 0 ] ! = K | | input2 . size [ 1 ] ! = N )
if ( input2 . dims > 2 | | input2 . size [ 0 ] ! = K | | input2 . size [ 1 ] ! = N )
{
{
int shape2 [ ] = { K , N } ;
int shape2 [ ] = { K , N } ;
reshapedInput2 = input2 . reshape ( 1 , 2 , shape2 ) ;
reshapedInput2 = input2 . reshape ( 1 , 2 , shape2 ) ;
}
}
Mat tmp_output ( M , N , reshapedInput1 . type ( ) ) ;
output = Mat ( M , N , reshapedInput1 . type ( ) ) ;
fastGemm ( false , false , 1.0 , reshapedInput1 , reshapedInput2 , 0.0 , tmp_output , opt ) ;
fastGemm ( false , false , 1.0 , reshapedInput1 , reshapedInput2 , 0.0 , output , opt ) ;
int newShape [ ] = { 1 , M , N } ;
tmp_output = tmp_output . reshape ( 1 , sizeof ( newShape ) / sizeof ( newShape [ 0 ] ) , newShape ) ;
output . emplace_back ( tmp_output ) ;
}
int outputDim [ ] = { static_cast < int > ( output . size ( ) ) , M , N } ;
Mat output_buffer = Mat : : zeros ( 3 , outputDim , CV_32F ) ;
for ( size_t i = 0 ; i < output . size ( ) ; i + + ) {
output = output . reshape ( 1 , { 1 , M , N } ) ;
Mat output_slice = output_buffer . row ( i ) ;
output [ i ] . copyTo ( output_slice ) ;
}
}
return output_buffer ;
return output ;
} ;
} ;
Ptr < EinsumLayer > EinsumLayer : : create ( const LayerParams & params )
Ptr < EinsumLayer > EinsumLayer : : create ( const LayerParams & params )
{
{