@@ -72,7 +72,7 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
     int outCn = outShape[1];
     size_t outPlaneSize = outShape[2]*outShape[3];
     float r0 = 1.f, r1 = 1.f, r2 = 1.f;
-    __m256 vr0 = _mm256_set1_ps(1.f), vr1 = vr0, vr2 = vr0, z = _mm256_setzero_ps();
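+    // these broadcast ReLU slopes feed only the 4-wide tail loop below, so 128-bit registers suffice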
+    __m128 vr0 = _mm_set1_ps(1.f), vr1 = vr0, vr2 = vr0, z = _mm_setzero_ps();
 
     // now compute dot product of the weights
     // and im2row-transformed part of the tensor
@@ -104,9 +104,9 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
             r0 = relu[i];
             r1 = relu[i+1];
             r2 = relu[i+2];
-            vr0 = _mm256_set1_ps(r0);
-            vr1 = _mm256_set1_ps(r1);
-            vr2 = _mm256_set1_ps(r2);
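+            // per-output-channel leaky-ReLU slopes for the three channels handled in this iteration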
+            vr0 = _mm_set1_ps(r0);
+            vr1 = _mm_set1_ps(r1);
+            vr2 = _mm_set1_ps(r2);
         }
 
         int j = 0;
@@ -156,38 +156,38 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
             t1 = _mm256_add_ps(t1, _mm256_permute2f128_ps(t1, t1, 1));
             t2 = _mm256_add_ps(t2, _mm256_permute2f128_ps(t2, t2, 1));
 
-            __m256 s0, s1, s2;
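+            // the swap-and-add reduction above duplicates each horizontal sum into both 128-bit halves
+            // of t0..t2; only the low half is consumed from here on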
+            __m128 s0, s1, s2;
 
             if( initOutput )
             {
-                s0 = _mm256_set1_ps(bias0);
-                s1 = _mm256_set1_ps(bias1);
-                s2 = _mm256_set1_ps(bias2);
+                s0 = _mm_set1_ps(bias0);
+                s1 = _mm_set1_ps(bias1);
+                s2 = _mm_set1_ps(bias2);
             }
             else
             {
-                s0 = _mm256_castps128_ps256(_mm_loadu_ps(outptr0 + j));
-                s1 = _mm256_castps128_ps256(_mm_loadu_ps(outptr1 + j));
-                s2 = _mm256_castps128_ps256(_mm_loadu_ps(outptr2 + j));
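+                // accumulate into the existing 4 outputs directly; widening the loads to 256 bits is unnecessary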
+                s0 = _mm_loadu_ps(outptr0 + j);
+                s1 = _mm_loadu_ps(outptr1 + j);
+                s2 = _mm_loadu_ps(outptr2 + j);
             }
 
-            s0 = _mm256_add_ps(s0, t0);
-            s1 = _mm256_add_ps(s1, t1);
-            s2 = _mm256_add_ps(s2, t2);
+            s0 = _mm_add_ps(s0, _mm256_castps256_ps128(t0));
+            s1 = _mm_add_ps(s1, _mm256_castps256_ps128(t1));
+            s2 = _mm_add_ps(s2, _mm256_castps256_ps128(t2));
 
             if( relu )
             {
-                __m256 m0 = _mm256_cmp_ps(s0, z, _CMP_GT_OS);
-                __m256 m1 = _mm256_cmp_ps(s1, z, _CMP_GT_OS);
-                __m256 m2 = _mm256_cmp_ps(s2, z, _CMP_GT_OS);
-                s0 = _mm256_xor_ps(s0, _mm256_andnot_ps(m0, _mm256_xor_ps(_mm256_mul_ps(s0, vr0), s0)));
-                s1 = _mm256_xor_ps(s1, _mm256_andnot_ps(m1, _mm256_xor_ps(_mm256_mul_ps(s1, vr1), s1)));
-                s2 = _mm256_xor_ps(s2, _mm256_andnot_ps(m2, _mm256_xor_ps(_mm256_mul_ps(s2, vr2), s2)));
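+                // branchless leaky ReLU: the andnot/xor pair substitutes s*slope exactly in the lanes where s <= 0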
+                __m128 m0 = _mm_cmp_ps(s0, z, _CMP_GT_OS);
+                __m128 m1 = _mm_cmp_ps(s1, z, _CMP_GT_OS);
+                __m128 m2 = _mm_cmp_ps(s2, z, _CMP_GT_OS);
+                s0 = _mm_xor_ps(s0, _mm_andnot_ps(m0, _mm_xor_ps(_mm_mul_ps(s0, vr0), s0)));
+                s1 = _mm_xor_ps(s1, _mm_andnot_ps(m1, _mm_xor_ps(_mm_mul_ps(s1, vr1), s1)));
+                s2 = _mm_xor_ps(s2, _mm_andnot_ps(m2, _mm_xor_ps(_mm_mul_ps(s2, vr2), s2)));
             }
 
-            _mm_storeu_ps(outptr0 + j, _mm256_castps256_ps128(s0));
-            _mm_storeu_ps(outptr1 + j, _mm256_castps256_ps128(s1));
-            _mm_storeu_ps(outptr2 + j, _mm256_castps256_ps128(s2));
+            _mm_storeu_ps(outptr0 + j, s0);
+            _mm_storeu_ps(outptr1 + j, s1);
+            _mm_storeu_ps(outptr2 + j, s2);
         }
 
         for( ; j < blockSize; j++ )
@@ -294,11 +294,63 @@ void fastGEMM1T( const float* vec, const float* weights,
     _mm256_zeroupper();
 }
 
 void fastGEMM( const float* aptr, size_t astep, const float* bptr,
                size_t bstep, float* cptr, size_t cstep,
                int ma, int na, int nb )
 {
     int n = 0;
 
+#ifdef CV_AVX512
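+    // AVX-512 path: each iteration computes a 4x32 tile of C (4 rows of A against 32 columns of B)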
+    for( ; n <= nb - 32; n += 32 )
+    {
+        for( int m = 0; m < ma; m += 4 )
+        {
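+            // row indices are clamped with std::min so a trailing partial group of rows stays in
+            // bounds; the last row is simply recomputed for the padding slots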
+            const float* aptr0 = aptr + astep*m;
+            const float* aptr1 = aptr + astep*std::min(m+1, ma-1);
+            const float* aptr2 = aptr + astep*std::min(m+2, ma-1);
+            const float* aptr3 = aptr + astep*std::min(m+3, ma-1);
+
+            float* cptr0 = cptr + cstep*m;
+            float* cptr1 = cptr + cstep*std::min(m+1, ma-1);
+            float* cptr2 = cptr + cstep*std::min(m+2, ma-1);
+            float* cptr3 = cptr + cstep*std::min(m+3, ma-1);
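+            // eight 512-bit accumulators keep the whole 4x32 result tile in registers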
+            __m512 d00 = _mm512_setzero_ps(), d01 = _mm512_setzero_ps();
+            __m512 d10 = _mm512_setzero_ps(), d11 = _mm512_setzero_ps();
+            __m512 d20 = _mm512_setzero_ps(), d21 = _mm512_setzero_ps();
+            __m512 d30 = _mm512_setzero_ps(), d31 = _mm512_setzero_ps();
+
+            for( int k = 0; k < na; k++ )
+            {
+                __m512 a0 = _mm512_set1_ps(aptr0[k]);
+                __m512 a1 = _mm512_set1_ps(aptr1[k]);
+                __m512 a2 = _mm512_set1_ps(aptr2[k]);
+                __m512 a3 = _mm512_set1_ps(aptr3[k]);
+                __m512 b0 = _mm512_loadu_ps(bptr + k*bstep + n);
+                __m512 b1 = _mm512_loadu_ps(bptr + k*bstep + n + 16);
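+                // one broadcast element per A row, FMA'd against the two 16-wide slices of the B row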
+                d00 = _mm512_fmadd_ps(a0, b0, d00);
+                d01 = _mm512_fmadd_ps(a0, b1, d01);
+                d10 = _mm512_fmadd_ps(a1, b0, d10);
+                d11 = _mm512_fmadd_ps(a1, b1, d11);
+                d20 = _mm512_fmadd_ps(a2, b0, d20);
+                d21 = _mm512_fmadd_ps(a2, b1, d21);
+                d30 = _mm512_fmadd_ps(a3, b0, d30);
+                d31 = _mm512_fmadd_ps(a3, b1, d31);
+            }
+
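+            // write the finished 4x32 tile back to C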
+            _mm512_storeu_ps(cptr0 + n, d00);
+            _mm512_storeu_ps(cptr0 + n + 16, d01);
+            _mm512_storeu_ps(cptr1 + n, d10);
+            _mm512_storeu_ps(cptr1 + n + 16, d11);
+            _mm512_storeu_ps(cptr2 + n, d20);
+            _mm512_storeu_ps(cptr2 + n + 16, d21);
+            _mm512_storeu_ps(cptr3 + n, d30);
+            _mm512_storeu_ps(cptr3 + n + 16, d31);
+        }
+    }
+#endif
+
     for( ; n <= nb - 16; n += 16 )
     {
         for( int m = 0; m < ma; m += 4 )