Merge pull request #3108 from LeszekSwirski:fix-gemm-buf-allocate-2.4

pull/3031/head
Vadim Pisarevsky 11 years ago
commit 7409f21e9f
  1. 13
      modules/core/src/matmul.cpp

@ -1013,6 +1013,7 @@ void cv::gemm( InputArray matA, InputArray matB, double alpha,
GEMMBlockMulFunc blockMulFunc; GEMMBlockMulFunc blockMulFunc;
GEMMStoreFunc storeFunc; GEMMStoreFunc storeFunc;
Mat *matD = &D, tmat; Mat *matD = &D, tmat;
int tmat_size = 0;
const uchar* Cdata = C.data; const uchar* Cdata = C.data;
size_t Cstep = C.data ? (size_t)C.step : 0; size_t Cstep = C.data ? (size_t)C.step : 0;
AutoBuffer<uchar> buf; AutoBuffer<uchar> buf;
@ -1045,8 +1046,8 @@ void cv::gemm( InputArray matA, InputArray matB, double alpha,
if( D.data == A.data || D.data == B.data ) if( D.data == A.data || D.data == B.data )
{ {
buf.allocate(d_size.width*d_size.height*CV_ELEM_SIZE(type)); tmat_size = d_size.width*d_size.height*CV_ELEM_SIZE(type);
tmat = Mat(d_size.height, d_size.width, type, (uchar*)buf ); // Allocate tmat later, once the size of buf is known
matD = &tmat; matD = &tmat;
} }
@ -1123,6 +1124,10 @@ void cv::gemm( InputArray matA, InputArray matB, double alpha,
(d_size.width <= block_lin_size && (d_size.width <= block_lin_size &&
d_size.height <= block_lin_size && len <= block_lin_size) ) d_size.height <= block_lin_size && len <= block_lin_size) )
{ {
if( tmat_size > 0 ) {
buf.allocate(tmat_size);
tmat = Mat(d_size.height, d_size.width, type, (uchar*)buf );
}
singleMulFunc( A.data, A.step, B.data, b_step, Cdata, Cstep, singleMulFunc( A.data, A.step, B.data, b_step, Cdata, Cstep,
matD->data, matD->step, a_size, d_size, alpha, beta, flags ); matD->data, matD->step, a_size, d_size, alpha, beta, flags );
} }
@ -1182,12 +1187,14 @@ void cv::gemm( InputArray matA, InputArray matB, double alpha,
flags &= ~GEMM_1_T; flags &= ~GEMM_1_T;
} }
buf.allocate(a_buf_size + b_buf_size + d_buf_size); buf.allocate(d_buf_size + b_buf_size + a_buf_size + tmat_size);
d_buf = (uchar*)buf; d_buf = (uchar*)buf;
b_buf = d_buf + d_buf_size; b_buf = d_buf + d_buf_size;
if( is_a_t ) if( is_a_t )
a_buf = b_buf + b_buf_size; a_buf = b_buf + b_buf_size;
if( tmat_size > 0 )
tmat = Mat(d_size.height, d_size.width, type, b_buf + b_buf_size + a_buf_size );
for( i = 0; i < d_size.height; i += di ) for( i = 0; i < d_size.height; i += di )
{ {

Loading…
Cancel
Save