@@ -112,6 +112,7 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
         int j = 0;
         for( ; j <= blockSize - 4; j += 4 )
         {
+            int k = 0;
             const float* rptr = rowbuf + j*vecsize_aligned;
 
             __m256 vs00 = _mm256_setzero_ps(), vs01 = _mm256_setzero_ps(),
@@ -121,7 +122,65 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
                    vs20 = _mm256_setzero_ps(), vs21 = _mm256_setzero_ps(),
                    vs22 = _mm256_setzero_ps(), vs23 = _mm256_setzero_ps();
 
-            for( int k = 0; k < vecsize; k += 8, rptr += 8 )
+#if CV_AVX512_SKX // AVX512VL is necessary to avoid register spilling
+            if (vecsize >= 32)
+            {
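+                // note: with vecsize >= 32 the 16-wide loop below runs at least
+                // twice, so the setup and fold-down overhead should be amortized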
+                __m512 vs00_5 = _mm512_setzero_ps(), vs01_5 = _mm512_setzero_ps(),
+                       vs02_5 = _mm512_setzero_ps(), vs03_5 = _mm512_setzero_ps(),
+                       vs10_5 = _mm512_setzero_ps(), vs11_5 = _mm512_setzero_ps(),
+                       vs12_5 = _mm512_setzero_ps(), vs13_5 = _mm512_setzero_ps(),
+                       vs20_5 = _mm512_setzero_ps(), vs21_5 = _mm512_setzero_ps(),
+                       vs22_5 = _mm512_setzero_ps(), vs23_5 = _mm512_setzero_ps();
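+                // the 12 accumulators form a 3x4 register tile: 3 output channels
+                // (wptr0..wptr2) times 4 adjacent output positions in rowbuf,
+                // mirroring the vs00..vs23 tile of the AVX2 loop below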
+                for (; k <= vecsize - 16; k += 16, rptr += 16)
+                {
+                    __m512 w0 = _mm512_loadu_ps(wptr0 + k);
+                    __m512 w1 = _mm512_loadu_ps(wptr1 + k);
+                    __m512 w2 = _mm512_loadu_ps(wptr2 + k);
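+                    // the three weight vectors are loaded once per iteration and
+                    // reused across all four output positions; only r0 is reloaded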
+                    __m512 r0 = _mm512_loadu_ps(rptr);
+                    vs00_5 = _mm512_fmadd_ps(w0, r0, vs00_5);
+                    vs10_5 = _mm512_fmadd_ps(w1, r0, vs10_5);
+                    vs20_5 = _mm512_fmadd_ps(w2, r0, vs20_5);
+                    r0 = _mm512_loadu_ps(rptr + vecsize_aligned);
+                    vs01_5 = _mm512_fmadd_ps(w0, r0, vs01_5);
+                    vs11_5 = _mm512_fmadd_ps(w1, r0, vs11_5);
+                    vs21_5 = _mm512_fmadd_ps(w2, r0, vs21_5);
+                    r0 = _mm512_loadu_ps(rptr + vecsize_aligned*2);
+                    vs02_5 = _mm512_fmadd_ps(w0, r0, vs02_5);
+                    vs12_5 = _mm512_fmadd_ps(w1, r0, vs12_5);
+                    vs22_5 = _mm512_fmadd_ps(w2, r0, vs22_5);
+                    r0 = _mm512_loadu_ps(rptr + vecsize_aligned*3);
+                    vs03_5 = _mm512_fmadd_ps(w0, r0, vs03_5);
+                    vs13_5 = _mm512_fmadd_ps(w1, r0, vs13_5);
+                    vs23_5 = _mm512_fmadd_ps(w2, r0, vs23_5);
+                }
+                /*
+                 * now fold the 512-bit accumulator vectors into 256-bit vectors
+                 * so that the AVX2 code below can finish the tail of the vector
+                 */
+                vs00 = _mm256_add_ps(_mm512_extractf32x8_ps(vs00_5, 0), _mm512_extractf32x8_ps(vs00_5, 1));
+                vs10 = _mm256_add_ps(_mm512_extractf32x8_ps(vs10_5, 0), _mm512_extractf32x8_ps(vs10_5, 1));
+                vs20 = _mm256_add_ps(_mm512_extractf32x8_ps(vs20_5, 0), _mm512_extractf32x8_ps(vs20_5, 1));
+                vs01 = _mm256_add_ps(_mm512_extractf32x8_ps(vs01_5, 0), _mm512_extractf32x8_ps(vs01_5, 1));
+                vs11 = _mm256_add_ps(_mm512_extractf32x8_ps(vs11_5, 0), _mm512_extractf32x8_ps(vs11_5, 1));
+                vs21 = _mm256_add_ps(_mm512_extractf32x8_ps(vs21_5, 0), _mm512_extractf32x8_ps(vs21_5, 1));
+                vs02 = _mm256_add_ps(_mm512_extractf32x8_ps(vs02_5, 0), _mm512_extractf32x8_ps(vs02_5, 1));
+                vs12 = _mm256_add_ps(_mm512_extractf32x8_ps(vs12_5, 0), _mm512_extractf32x8_ps(vs12_5, 1));
+                vs22 = _mm256_add_ps(_mm512_extractf32x8_ps(vs22_5, 0), _mm512_extractf32x8_ps(vs22_5, 1));
+                vs03 = _mm256_add_ps(_mm512_extractf32x8_ps(vs03_5, 0), _mm512_extractf32x8_ps(vs03_5, 1));
+                vs13 = _mm256_add_ps(_mm512_extractf32x8_ps(vs13_5, 0), _mm512_extractf32x8_ps(vs13_5, 1));
+                vs23 = _mm256_add_ps(_mm512_extractf32x8_ps(vs23_5, 0), _mm512_extractf32x8_ps(vs23_5, 1));
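+                // (_mm512_extractf32x8_ps needs AVX512DQ, which the
+                //  CV_AVX512_SKX target presumably implies)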
+            }
+#endif
+            for( ; k < vecsize; k += 8, rptr += 8 )
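+            // AVX2 path: resumes at whatever k the AVX512 loop reached and
+            // processes the remaining elements 8 floats at a time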
             {
                 __m256 w0 = _mm256_load_ps(wptr0 + k);
                 __m256 w1 = _mm256_load_ps(wptr1 + k);