@@ -50,6 +50,16 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
                const float* rowbuf, float* output, const int* outShape,
                int blockSize, int vecsize, int vecsize_aligned,
                const float* relu, bool initOutput );
+void fastDepthwiseConv( const float* weights,
+                        int kernel_h, int kernel_w,
+                        int stride_h, int stride_w,
+                        int dilation_h, int dilation_w,
+                        int pad_t, int pad_l,
+                        const float* bias, const float* relu,
+                        const float* inptr,
+                        int height, int width,
+                        float* outptr,
+                        int out_d, int outH, int outW );
 void fastGEMM1T( const float* vec, const float* weights,
                  size_t wstep, const float* bias,
                  float* dst, int nvecs, int vecsize );
@@ -64,6 +74,8 @@ void fastGEMM( const float* aptr, size_t astep, const float* bptr,
 #define _mm256_fmadd_ps(a, b, c) _mm256_add_ps(c, _mm256_mul_ps(a, b))
 #endif
 
+enum { FASCONV_BASE_VECSZ = 4 };
+
 void fastConv( const float* weights, size_t wstep, const float* bias,
                const float* rowbuf, float* output, const int* outShape,
                int blockSize, int vecsize, int vecsize_aligned,
@@ -73,6 +85,11 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
     size_t outPlaneSize = outShape[2]*outShape[3];
     float r0 = 1.f, r1 = 1.f, r2 = 1.f;
     __m128 vr0 = _mm_set1_ps(1.f), vr1 = vr0, vr2 = vr0, z = _mm_setzero_ps();
+    int CV_DECL_ALIGNED(16) maskbuf[FASCONV_BASE_VECSZ] = {0};
+    int rsz = blockSize % FASCONV_BASE_VECSZ;
+    for( int i = 0; i < rsz; i++ )
+        maskbuf[FASCONV_BASE_VECSZ - i - 1] = -1;
+    __m128 mask = _mm_loadu_ps((const float*)maskbuf);
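+    // the last rsz lanes of 'mask' are all-ones; when the final block of outputs is shorter
+    // than FASCONV_BASE_VECSZ, it selects which lanes of the shifted-back vector are new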
     // now compute dot product of the weights
     // and im2row-transformed part of the tensor
@@ -114,8 +131,16 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
         }
         int j = 0;
-        for( ; j <= blockSize - 4; j += 4 )
+        for( ; j < blockSize; j += FASCONV_BASE_VECSZ )
         {
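+            // if fewer than FASCONV_BASE_VECSZ outputs remain, shift j back so a full vector
+            // still fits; 'mask' is then used before the store to keep the overlapping,
+            // already-computed lanes unchanged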
+            bool tail = false;
+            if (j + FASCONV_BASE_VECSZ > blockSize)
+            {
+                if (j == 0)
+                    break;
+                j = blockSize - FASCONV_BASE_VECSZ;
+                tail = true;
+            }
             int k = 0;
             const float* rptr = rowbuf + j*vecsize_aligned;
@@ -243,9 +268,16 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
                 __m128 m0 = _mm_cmp_ps(s0, z, _CMP_GT_OS);
                 __m128 m1 = _mm_cmp_ps(s1, z, _CMP_GT_OS);
                 __m128 m2 = _mm_cmp_ps(s2, z, _CMP_GT_OS);
-                s0 = _mm_xor_ps(s0, _mm_andnot_ps(m0, _mm_xor_ps(_mm_mul_ps(s0, vr0), s0)));
-                s1 = _mm_xor_ps(s1, _mm_andnot_ps(m1, _mm_xor_ps(_mm_mul_ps(s1, vr1), s1)));
-                s2 = _mm_xor_ps(s2, _mm_andnot_ps(m2, _mm_xor_ps(_mm_mul_ps(s2, vr2), s2)));
+                s0 = _mm_blendv_ps(_mm_mul_ps(s0, vr0), s0, m0);
+                s1 = _mm_blendv_ps(_mm_mul_ps(s1, vr1), s1, m1);
+                s2 = _mm_blendv_ps(_mm_mul_ps(s2, vr2), s2, m2);
             }
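+            // shortened last block: blend the new lanes (selected by 'mask') with the values
+            // already stored at outptr0..2 + j before writing the full vector back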
+            if (tail)
+            {
+                s0 = _mm_blendv_ps(_mm_loadu_ps(outptr0 + j), s0, mask);
+                s1 = _mm_blendv_ps(_mm_loadu_ps(outptr1 + j), s1, mask);
+                s2 = _mm_blendv_ps(_mm_loadu_ps(outptr2 + j), s2, mask);
+            }
 
             _mm_storeu_ps(outptr0 + j, s0);
@@ -253,9 +285,55 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
             _mm_storeu_ps(outptr2 + j, s2);
         }
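+        // scalar tail, reached only when blockSize < FASCONV_BASE_VECSZ (the vectorized loop
+        // breaks out immediately in that case): process output pairs, then single outputs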
+        for( ; j <= blockSize - 2; j += 2 )
+        {
+            const float* rptr0 = rowbuf + j*vecsize_aligned;
+            const float* rptr1 = rowbuf + (j+1)*vecsize_aligned;
+            float s00, s01, s10, s11, s20, s21;
+
+            if( initOutput )
+            {
+                s00 = s01 = bias0;
+                s10 = s11 = bias1;
+                s20 = s21 = bias2;
+            }
+            else
+            {
+                s00 = outptr0[j]; s01 = outptr0[j+1];
+                s10 = outptr1[j]; s11 = outptr1[j+1];
+                s20 = outptr2[j]; s21 = outptr2[j+1];
+            }
+
+            for( int k = 0; k < vecsize; k++ )
+            {
+                float w0 = wptr0[k], w1 = wptr1[k], w2 = wptr2[k];
+                float r = rptr0[k];
+                s00 += w0*r; s10 += w1*r; s20 += w2*r;
+                r = rptr1[k];
+                s01 += w0*r; s11 += w1*r; s21 += w2*r;
+            }
+
+            if( relu )
+            {
+                s00 = s00 > 0.f ? s00 : s00*r0;
+                s01 = s01 > 0.f ? s01 : s01*r0;
+                s10 = s10 > 0.f ? s10 : s10*r1;
+                s11 = s11 > 0.f ? s11 : s11*r1;
+                s20 = s20 > 0.f ? s20 : s20*r2;
+                s21 = s21 > 0.f ? s21 : s21*r2;
+            }
+
+            outptr0[j] = s00;
+            outptr0[j+1] = s01;
+            outptr1[j] = s10;
+            outptr1[j+1] = s11;
+            outptr2[j] = s20;
+            outptr2[j+1] = s21;
+        }
         for( ; j < blockSize; j++ )
         {
-            const float* rptr = rowbuf + j*vecsize_aligned;
+            const float* rptr0 = rowbuf + j*vecsize_aligned;
             float s00, s10, s20;
 
             if( initOutput )
@@ -273,10 +351,9 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
             for( int k = 0; k < vecsize; k++ )
             {
-                float r0 = rptr[k];
-                s00 += wptr0[k]*r0;
-                s10 += wptr1[k]*r0;
-                s20 += wptr2[k]*r0;
+                float w0 = wptr0[k], w1 = wptr1[k], w2 = wptr2[k];
+                float r = rptr0[k];
+                s00 += w0*r; s10 += w1*r; s20 += w2*r;
             }
 
             if( relu )
@@ -294,6 +371,185 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
     _mm256_zeroupper();
 }
 
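+// split 16 consecutive floats into even-indexed (a) and odd-indexed (b) elements, e.g.
+// {0,1,...,15} -> a = {0,2,...,14}, b = {1,3,...,15}; used by the stride_w == 2 path below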
+static inline void _mm256_load_deinterleave(const float* ptr, __m256& a, __m256& b)
+{
+    __m256 t0 = _mm256_loadu_ps(ptr);
+    __m256 t1 = _mm256_loadu_ps(ptr + 8);
+
+    __m256 lo = _mm256_permute2f128_ps(t0, t1, 0 + 2*16);
+    __m256 hi = _mm256_permute2f128_ps(t0, t1, 1 + 3*16);
+    a = _mm256_shuffle_ps(lo, hi, 0x88);
+    b = _mm256_shuffle_ps(lo, hi, 0xdd);
+}
+
+void fastDepthwiseConv( const float* wptr,
+                        int kernel_h, int kernel_w,
+                        int stride_h, int stride_w,
+                        int dilation_h, int dilation_w,
+                        int pad_t, int pad_l,
+                        const float* biasptr, const float* relu,
+                        const float* inptr_,
+                        int height, int width,
+                        float* outptr_,
+                        int out_d, int outH, int outW )
+{
+    const float w00_ = wptr[0], w01_ = wptr[1], w02_ = wptr[2],
+                w10 = wptr[3], w11 = wptr[4], w12 = wptr[5],
+                w20_ = wptr[6], w21_ = wptr[7], w22_ = wptr[8];
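+    // output columns below outW1 have their whole (dilated) 3-tap window inside the row on
+    // the right, so the loops over [0, outW1) need no horizontal border checks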
+    int outW1 = min(outW, (width - dilation_w*(kernel_w - 1) + pad_l)/stride_w);
+    float relu_coeff = relu ? relu[out_d] : 1.f, bias = biasptr[out_d];
+
+    for (int out_i = 0; out_i < outH; out_i++)
+    {
+        int in_i = out_i * stride_h - pad_t, out_j = 0;
+        const float* imgptr0 = inptr_ + in_i*width;
+        const float* imgptr1 = imgptr0 + dilation_h*width;
+        const float* imgptr2 = imgptr0 + (dilation_h*2)*width;
+        float out, w00 = w00_, w01 = w01_, w02 = w02_;
+        float w20 = w20_, w21 = w21_, w22 = w22_;
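+        // kernel rows that fall above/below the image: zero their weights and alias the
+        // corresponding row pointer to a valid row so the loads stay in bounds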
+        if (in_i < 0)
+        {
+            w00 = w01 = w02 = 0.f;
+            imgptr0 = imgptr1;
+        }
+        else if (in_i + dilation_h*(kernel_h-1) >= height)
+        {
+            w20 = w21 = w22 = 0.f;
+            imgptr2 = imgptr1;
+        }
+        float* outptr = outptr_ + out_i*outW;
+        if (pad_l > 0)
+        {
+            out = imgptr0[0]*w01 + imgptr0[dilation_w]*w02 +
+                  imgptr1[0]*w11 + imgptr1[dilation_w]*w12 +
+                  imgptr2[0]*w21 + imgptr2[dilation_w]*w22 + bias;
+            if (relu)
+                out = out > 0.f ? out : out*relu_coeff;
+            outptr[0] = out;
+            out_j = 1;
+        }
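+
+        // vectorized inner part: 8 outputs per iteration; stride_w == 2 gathers every second
+        // input column with the even/odd deinterleave helper defined above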
+        if (stride_w == 1 || (stride_w == 2 && dilation_w == 1))
+        {
+            const int VECSZ = 8;
+            __m256 vw00 = _mm256_set1_ps(w00), vw01 = _mm256_set1_ps(w01), vw02 = _mm256_set1_ps(w02),
+                   vw10 = _mm256_set1_ps(w10), vw11 = _mm256_set1_ps(w11), vw12 = _mm256_set1_ps(w12),
+                   vw20 = _mm256_set1_ps(w20), vw21 = _mm256_set1_ps(w21), vw22 = _mm256_set1_ps(w22);
+            __m256 z = _mm256_setzero_ps(), vbias = _mm256_set1_ps(bias), vrc = _mm256_set1_ps(relu_coeff);
+
+            if( stride_w == 1 )
+                for( ; out_j < outW1; out_j += VECSZ )
+                {
+                    if (out_j + VECSZ > outW1 && out_j > pad_l)
+                        out_j = outW1 - VECSZ;
+                    int in_j = out_j * stride_w - pad_l;
+                    __m256 v00 = _mm256_loadu_ps(imgptr0 + in_j),
+                           v01 = _mm256_loadu_ps(imgptr0 + in_j + dilation_w),
+                           v02 = _mm256_loadu_ps(imgptr0 + in_j + dilation_w*2),
+                           v10 = _mm256_loadu_ps(imgptr1 + in_j),
+                           v11 = _mm256_loadu_ps(imgptr1 + in_j + dilation_w),
+                           v12 = _mm256_loadu_ps(imgptr1 + in_j + dilation_w*2),
+                           v20 = _mm256_loadu_ps(imgptr2 + in_j),
+                           v21 = _mm256_loadu_ps(imgptr2 + in_j + dilation_w),
+                           v22 = _mm256_loadu_ps(imgptr2 + in_j + dilation_w*2);
+
+                    __m256 vout0 = _mm256_fmadd_ps(v00, vw00, vbias);
+                    __m256 vout1 = _mm256_mul_ps(v01, vw01);
+                    __m256 vout2 = _mm256_mul_ps(v02, vw02);
+
+                    vout0 = _mm256_fmadd_ps(v10, vw10, vout0);
+                    vout1 = _mm256_fmadd_ps(v11, vw11, vout1);
+                    vout2 = _mm256_fmadd_ps(v12, vw12, vout2);
+
+                    vout0 = _mm256_fmadd_ps(v20, vw20, vout0);
+                    vout1 = _mm256_fmadd_ps(v21, vw21, vout1);
+                    vout2 = _mm256_fmadd_ps(v22, vw22, vout2);
+
+                    vout0 = _mm256_add_ps(_mm256_add_ps(vout0, vout1), vout2);
+                    if (relu)
+                    {
+                        __m256 m = _mm256_cmp_ps(vout0, z, _CMP_GT_OQ);
+                        vout0 = _mm256_blendv_ps(_mm256_mul_ps(vout0, vrc), vout0, m);
+                    }
+                    _mm256_storeu_ps(outptr + out_j, vout0);
+                }
+            else
+                for( ; out_j < outW1; out_j += VECSZ )
+                {
+                    if (out_j + VECSZ > outW1 && out_j > pad_l)
+                        out_j = outW1 - VECSZ;
+                    int in_j = out_j * stride_w - pad_l;
+                    __m256 v00, v01, v02, v10, v11, v12, v20, v21, v22, unused;
+                    _mm256_load_deinterleave(imgptr0 + in_j, v00, v01);
+                    _mm256_load_deinterleave(imgptr0 + in_j + 2, v02, unused);
+                    _mm256_load_deinterleave(imgptr1 + in_j, v10, v11);
+                    _mm256_load_deinterleave(imgptr1 + in_j + 2, v12, unused);
+                    _mm256_load_deinterleave(imgptr2 + in_j, v20, v21);
+                    _mm256_load_deinterleave(imgptr2 + in_j + 2, v22, unused);
+
+                    __m256 vout0 = _mm256_fmadd_ps(v00, vw00, vbias);
+                    __m256 vout1 = _mm256_mul_ps(v01, vw01);
+                    __m256 vout2 = _mm256_mul_ps(v02, vw02);
+
+                    vout0 = _mm256_fmadd_ps(v10, vw10, vout0);
+                    vout1 = _mm256_fmadd_ps(v11, vw11, vout1);
+                    vout2 = _mm256_fmadd_ps(v12, vw12, vout2);
+
+                    vout0 = _mm256_fmadd_ps(v20, vw20, vout0);
+                    vout1 = _mm256_fmadd_ps(v21, vw21, vout1);
+                    vout2 = _mm256_fmadd_ps(v22, vw22, vout2);
+
+                    vout0 = _mm256_add_ps(_mm256_add_ps(vout0, vout1), vout2);
+                    if (relu)
+                    {
+                        __m256 m = _mm256_cmp_ps(vout0, z, _CMP_GT_OQ);
+                        vout0 = _mm256_blendv_ps(_mm256_mul_ps(vout0, vrc), vout0, m);
+                    }
+                    _mm256_storeu_ps(outptr + out_j, vout0);
+                }
+        }
+
+        for (; out_j < outW1; out_j++)
+        {
+            int in_j = out_j * stride_w - pad_l;
+            out = imgptr0[in_j]*w00 + imgptr0[in_j + dilation_w]*w01 + imgptr0[in_j + dilation_w*2]*w02 +
+                  imgptr1[in_j]*w10 + imgptr1[in_j + dilation_w]*w11 + imgptr1[in_j + dilation_w*2]*w12 +
+                  imgptr2[in_j]*w20 + imgptr2[in_j + dilation_w]*w21 + imgptr2[in_j + dilation_w*2]*w22 + bias;
+            if (relu)
+                out = out > 0.f ? out : out*relu_coeff;
+            outptr[out_j] = out;
+        }
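+        // right border: any tap that falls past the end of the row is zeroed via s0/s1/s2,
+        // and its index is reset to 0 so the load stays in bounds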
+        for (; out_j < outW; out_j++)
+        {
+            int in_j0 = out_j * stride_w - pad_l, in_j1 = in_j0 + dilation_w, in_j2 = in_j0 + dilation_w*2;
+            float s0 = 1.f, s1 = 1.f, s2 = 1.f;
+            if (in_j0 >= width)
+            {
+                in_j0 = 0;
+                s0 = 0.f;
+            }
+            if (in_j1 >= width)
+            {
+                in_j1 = 0;
+                s1 = 0.f;
+            }
+            if (in_j2 >= width)
+            {
+                in_j2 = 0;
+                s2 = 0.f;
+            }
+            out = imgptr0[in_j0]*w00*s0 + imgptr0[in_j1]*w01*s1 + imgptr0[in_j2]*w02*s2 +
+                  imgptr1[in_j0]*w10*s0 + imgptr1[in_j1]*w11*s1 + imgptr1[in_j2]*w12*s2 +
+                  imgptr2[in_j0]*w20*s0 + imgptr2[in_j1]*w21*s1 + imgptr2[in_j2]*w22*s2 + bias;
+            if (relu)
+                out = out > 0.f ? out : out*relu_coeff;
+            outptr[out_j] = out;
+        }
+    }
+    _mm256_zeroupper();
+}
 // dst = vec * weights^t + bias
 void fastGEMM1T( const float* vec, const float* weights,
                  size_t wstep, const float* bias,