@ -28,7 +28,84 @@ inline void SwapRowCols(const Mat &A, int &rows, int &cols, bool isTrans)
cols = ( isTrans ) ? A . rows : A . cols ;
}
class GEMMInvoker : public ParallelLoopBody
{
public :
GEMMInvoker ( const Mat * _a , const Mat * _b , double _alpha , Mat * _c , double _beta )
{
a = _a ;
b = _b ;
c = _c ;
alpha = _alpha ;
beta = _beta ;
}
void operator ( ) ( const Range & range ) const
{
int mmax = a - > rows ;
int nmax = range . end - range . start ;
int kmax = a - > cols ;
int m , n , k ;
AutoBuffer < float > buf ( nmax ) ;
float * ptr = buf ;
if ( mmax % 2 ! = 0 )
memset ( ptr , 0 , nmax * sizeof ( ptr [ 0 ] ) ) ;
for ( m = 0 ; m < mmax ; m + = 2 )
{
float * dst0 = c - > ptr < float > ( m ) + range . start ;
float * dst1 = m + 1 < mmax ? c - > ptr < float > ( m + 1 ) + range . start : ptr ;
const float * aptr0 = a - > ptr < float > ( m ) ;
const float * aptr1 = m + 1 < mmax ? a - > ptr < float > ( m + 1 ) : aptr0 ;
if ( beta ! = 1 )
{
if ( beta = = 0 )
for ( n = 0 ; n < nmax ; n + + )
{
dst0 [ n ] = 0.f ;
dst1 [ n ] = 0.f ;
}
else
for ( n = 0 ; n < nmax ; n + + )
{
dst0 [ n ] * = ( float ) beta ;
dst1 [ n ] * = ( float ) beta ;
}
}
for ( k = 0 ; k < kmax ; k + + )
{
float alpha0 = ( float ) ( alpha * aptr0 [ k ] ) ;
float alpha1 = ( float ) ( alpha * aptr1 [ k ] ) ;
const float * bptr = b - > ptr < float > ( k ) + range . start ;
for ( n = 0 ; n < nmax ; n + + )
{
float d0 = dst0 [ n ] + alpha0 * bptr [ n ] ;
float d1 = dst1 [ n ] + alpha1 * bptr [ n ] ;
dst0 [ n ] = d0 ;
dst1 [ n ] = d1 ;
}
}
}
}
const Mat * a , * b ;
Mat * c ;
double alpha , beta ;
} ;
void gemmCPU ( const Mat & A , const Mat & B , double alpha , Mat & C , double beta , int flags /*= 0*/ )
{
if ( C . type ( ) = = CV_32F & & flags = = 0 )
{
GEMMInvoker invoker ( & A , & B , alpha , & C , beta ) ;
double granularity = 10000000. / ( ( double ) A . rows * A . cols ) ;
parallel_for_ ( Range ( 0 , B . cols ) , invoker , granularity ) ;
}
else
{
# if HAVE_CBLAS
bool transA = static_cast < bool > ( flags & GEMM_1_T ) ;
@ -71,6 +148,7 @@ void gemmCPU(const Mat &A, const Mat &B, double alpha, Mat &C, double beta, int
cv : : gemm ( A , B , alpha , C , beta , C , flags ) ;
# endif
}
}
int getBlasThreads ( )
{