|
|
|
@@ -46,10 +46,6 @@ namespace cv {
namespace dnn {
CV_CPU_OPTIMIZATION_NAMESPACE_BEGIN

void fastConv( const float* weights, size_t wstep, const float* bias,
               const float* rowbuf, float* output, const int* outShape,
               int blockSize, int vecsize, int vecsize_aligned,
               const float* relu, bool initOutput );
void fastDepthwiseConv( const float* weights,
                        int kernel_h, int kernel_w,
                        int stride_h, int stride_w,
@@ -74,305 +70,6 @@ void fastGEMM( const float* aptr, size_t astep, const float* bptr,
#define _mm256_fmadd_ps(a, b, c) _mm256_add_ps(c, _mm256_mul_ps(a, b))
#endif
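
// Editor's sketch (not part of the original sources): a plain scalar reference of
// what the vectorized fastConv kernels below compute, assuming only the layout the
// code itself implies -- `weights` holds outShape[1] rows of `wstep` floats,
// `rowbuf` holds `blockSize` rows of `vecsize_aligned` floats, and `relu`
// (when non-NULL) holds one negative slope per output channel.
static void fastConvReferenceSketch( const float* weights, size_t wstep, const float* bias,
                                     const float* rowbuf, float* output, const int* outShape,
                                     int blockSize, int vecsize, int vecsize_aligned,
                                     const float* relu, bool initOutput )
{
    int outCn = outShape[1];
    size_t outPlaneSize = (size_t)outShape[2]*outShape[3];
    for( int i = 0; i < outCn; i++ )
    {
        const float* wptr = weights + i*wstep;
        float* outptr = output + i*outPlaneSize;
        for( int j = 0; j < blockSize; j++ )
        {
            // start from the bias or accumulate into the existing output
            float s = initOutput ? bias[i] : outptr[j];
            for( int k = 0; k < vecsize; k++ )
                s += wptr[k]*rowbuf[j*vecsize_aligned + k];
            if( relu )
                s = s > 0.f ? s : s*relu[i];   // leaky-ReLU with per-channel slope
            outptr[j] = s;
        }
    }
}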
|
|
|
|
|
|
|
|
|
enum { FASCONV_BASE_VECSZ = 4 };

void fastConv( const float* weights, size_t wstep, const float* bias,
               const float* rowbuf, float* output, const int* outShape,
               int blockSize, int vecsize, int vecsize_aligned,
               const float* relu, bool initOutput )
{
    CV_Assert(isAligned<32>(weights));

    int outCn = outShape[1];
    size_t outPlaneSize = outShape[2]*outShape[3];
    float r0 = 1.f, r1 = 1.f, r2 = 1.f;
    __m128 vr0 = _mm_set1_ps(1.f), vr1 = vr0, vr2 = vr0, z = _mm_setzero_ps();
    int CV_DECL_ALIGNED(16) maskbuf[FASCONV_BASE_VECSZ] = {0};
    int rsz = blockSize % FASCONV_BASE_VECSZ;
    for( int i = 0; i < rsz; i++ )
        maskbuf[FASCONV_BASE_VECSZ - i - 1] = -1;
    __m128 mask = _mm_loadu_ps((const float*)maskbuf);
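    // Editor's note: the mask marks which lanes of the final (partial) 4-wide
    // block are new. For example, with blockSize = 7, rsz = 3 and
    // maskbuf = {0, -1, -1, -1}; the tail iteration below rewinds j to
    // blockSize - 4 = 3 and _mm_blendv_ps keeps the previously stored value
    // in the masked-off lane(s), so the overlapping lane is rewritten with
    // the same value it already holds.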
|
|
|
|
|
|
|
|
|
    // now compute dot product of the weights
    // and im2row-transformed part of the tensor
    for( int i = 0; i < outCn; i += 3 )
    {
        const float* wptr0 = weights + i*wstep;
        const float* wptr1 = wptr0 + wstep;
        const float* wptr2 = wptr1 + wstep;
        float* outptr0 = output + i*outPlaneSize;
        float* outptr1 = outptr0 + outPlaneSize;
        float* outptr2 = outptr1 + outPlaneSize;
        float bias0 = bias[i], bias1 = bias[i+1], bias2 = bias[i+2];

        if( i+2 >= outCn )
        {
            wptr2 = wptr1;
            outptr2 = outptr1;
            bias2 = bias1;
            if( i+1 >= outCn )
            {
                wptr2 = wptr1 = wptr0;
                outptr2 = outptr1 = outptr0;
                bias2 = bias1 = bias0;
            }
        }

        if( relu )
        {
            r0 = relu[i]; r1 = relu[i+1]; r2 = relu[i+2];
            if( i+2 >= outCn )
            {
                r2 = r1;
                if( i+1 >= outCn )
                    r2 = r1 = r0;
            }
            vr0 = _mm_set1_ps(r0);
            vr1 = _mm_set1_ps(r1);
            vr2 = _mm_set1_ps(r2);
        }

        int j = 0;
        for( ; j < blockSize; j += FASCONV_BASE_VECSZ )
        {
            bool tail = false;
            if (j + FASCONV_BASE_VECSZ > blockSize)
            {
                if (j == 0)
                    break;
                j = blockSize - FASCONV_BASE_VECSZ;
                tail = true;
            }
            int k = 0;
            const float* rptr = rowbuf + j*vecsize_aligned;

            __m256 vs00 = _mm256_setzero_ps(), vs01 = _mm256_setzero_ps(),
                   vs02 = _mm256_setzero_ps(), vs03 = _mm256_setzero_ps(),
                   vs10 = _mm256_setzero_ps(), vs11 = _mm256_setzero_ps(),
                   vs12 = _mm256_setzero_ps(), vs13 = _mm256_setzero_ps(),
                   vs20 = _mm256_setzero_ps(), vs21 = _mm256_setzero_ps(),
                   vs22 = _mm256_setzero_ps(), vs23 = _mm256_setzero_ps();
#if CV_AVX512_SKX // AVX512VL is necessary to avoid register spilling
            if (vecsize >= 32)
            {
                __m512 vs00_5 = _mm512_setzero_ps(), vs01_5 = _mm512_setzero_ps(),
                       vs02_5 = _mm512_setzero_ps(), vs03_5 = _mm512_setzero_ps(),
                       vs10_5 = _mm512_setzero_ps(), vs11_5 = _mm512_setzero_ps(),
                       vs12_5 = _mm512_setzero_ps(), vs13_5 = _mm512_setzero_ps(),
                       vs20_5 = _mm512_setzero_ps(), vs21_5 = _mm512_setzero_ps(),
                       vs22_5 = _mm512_setzero_ps(), vs23_5 = _mm512_setzero_ps();

                for (; k <= vecsize - 16; k += 16, rptr += 16)
                {
                    __m512 w0 = _mm512_loadu_ps(wptr0 + k);
                    __m512 w1 = _mm512_loadu_ps(wptr1 + k);
                    __m512 w2 = _mm512_loadu_ps(wptr2 + k);
                    __m512 r0 = _mm512_loadu_ps(rptr);

                    vs00_5 = _mm512_fmadd_ps(w0, r0, vs00_5);
                    vs10_5 = _mm512_fmadd_ps(w1, r0, vs10_5);
                    vs20_5 = _mm512_fmadd_ps(w2, r0, vs20_5);

                    r0 = _mm512_loadu_ps(rptr + vecsize_aligned);
                    vs01_5 = _mm512_fmadd_ps(w0, r0, vs01_5);
                    vs11_5 = _mm512_fmadd_ps(w1, r0, vs11_5);
                    vs21_5 = _mm512_fmadd_ps(w2, r0, vs21_5);

                    r0 = _mm512_loadu_ps(rptr + vecsize_aligned*2);
                    vs02_5 = _mm512_fmadd_ps(w0, r0, vs02_5);
                    vs12_5 = _mm512_fmadd_ps(w1, r0, vs12_5);
                    vs22_5 = _mm512_fmadd_ps(w2, r0, vs22_5);

                    r0 = _mm512_loadu_ps(rptr + vecsize_aligned*3);
                    vs03_5 = _mm512_fmadd_ps(w0, r0, vs03_5);
                    vs13_5 = _mm512_fmadd_ps(w1, r0, vs13_5);
                    vs23_5 = _mm512_fmadd_ps(w2, r0, vs23_5);
                }
                /*
                 * now fold the 512 bit accumulator vectors into 256 bit vectors so that the AVX2 code can finish
                 * the tail of the vector
                 */
                vs00 = _mm256_add_ps( _mm512_extractf32x8_ps(vs00_5, 0), _mm512_extractf32x8_ps(vs00_5, 1));
                vs10 = _mm256_add_ps( _mm512_extractf32x8_ps(vs10_5, 0), _mm512_extractf32x8_ps(vs10_5, 1));
                vs20 = _mm256_add_ps( _mm512_extractf32x8_ps(vs20_5, 0), _mm512_extractf32x8_ps(vs20_5, 1));

                vs01 = _mm256_add_ps( _mm512_extractf32x8_ps(vs01_5, 0), _mm512_extractf32x8_ps(vs01_5, 1));
                vs11 = _mm256_add_ps( _mm512_extractf32x8_ps(vs11_5, 0), _mm512_extractf32x8_ps(vs11_5, 1));
                vs21 = _mm256_add_ps( _mm512_extractf32x8_ps(vs21_5, 0), _mm512_extractf32x8_ps(vs21_5, 1));

                vs02 = _mm256_add_ps( _mm512_extractf32x8_ps(vs02_5, 0), _mm512_extractf32x8_ps(vs02_5, 1));
                vs12 = _mm256_add_ps( _mm512_extractf32x8_ps(vs12_5, 0), _mm512_extractf32x8_ps(vs12_5, 1));
                vs22 = _mm256_add_ps( _mm512_extractf32x8_ps(vs22_5, 0), _mm512_extractf32x8_ps(vs22_5, 1));

                vs03 = _mm256_add_ps( _mm512_extractf32x8_ps(vs03_5, 0), _mm512_extractf32x8_ps(vs03_5, 1));
                vs13 = _mm256_add_ps( _mm512_extractf32x8_ps(vs13_5, 0), _mm512_extractf32x8_ps(vs13_5, 1));
                vs23 = _mm256_add_ps( _mm512_extractf32x8_ps(vs23_5, 0), _mm512_extractf32x8_ps(vs23_5, 1));
            }
#endif
            for (; k < vecsize; k += 8, rptr += 8 )
            {
                __m256 w0 = _mm256_load_ps(wptr0 + k);
                __m256 w1 = _mm256_load_ps(wptr1 + k);
                __m256 w2 = _mm256_load_ps(wptr2 + k);
                __m256 r0 = _mm256_load_ps(rptr);

                vs00 = _mm256_fmadd_ps(w0, r0, vs00);
                vs10 = _mm256_fmadd_ps(w1, r0, vs10);
                vs20 = _mm256_fmadd_ps(w2, r0, vs20);

                r0 = _mm256_load_ps(rptr + vecsize_aligned);
                vs01 = _mm256_fmadd_ps(w0, r0, vs01);
                vs11 = _mm256_fmadd_ps(w1, r0, vs11);
                vs21 = _mm256_fmadd_ps(w2, r0, vs21);

                r0 = _mm256_load_ps(rptr + vecsize_aligned*2);
                vs02 = _mm256_fmadd_ps(w0, r0, vs02);
                vs12 = _mm256_fmadd_ps(w1, r0, vs12);
                vs22 = _mm256_fmadd_ps(w2, r0, vs22);

                r0 = _mm256_load_ps(rptr + vecsize_aligned*3);
                vs03 = _mm256_fmadd_ps(w0, r0, vs03);
                vs13 = _mm256_fmadd_ps(w1, r0, vs13);
                vs23 = _mm256_fmadd_ps(w2, r0, vs23);
            }

            __m256 t0 = _mm256_hadd_ps(_mm256_hadd_ps(vs00, vs01), _mm256_hadd_ps(vs02, vs03));
            __m256 t1 = _mm256_hadd_ps(_mm256_hadd_ps(vs10, vs11), _mm256_hadd_ps(vs12, vs13));
            __m256 t2 = _mm256_hadd_ps(_mm256_hadd_ps(vs20, vs21), _mm256_hadd_ps(vs22, vs23));

            t0 = _mm256_add_ps(t0, _mm256_permute2f128_ps(t0, t0, 1));
            t1 = _mm256_add_ps(t1, _mm256_permute2f128_ps(t1, t1, 1));
            t2 = _mm256_add_ps(t2, _mm256_permute2f128_ps(t2, t2, 1));
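
            // Editor's note: after the two hadd passes and the high/low 128-bit
            // add, the low 4 lanes of t0 hold the full horizontal sums of
            // vs00..vs03, i.e. the dot products for output channel i at pixels
            // j..j+3; t1 and t2 hold the same for channels i+1 and i+2.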
|
|
|
|
|
|
|
|
|
            __m128 s0, s1, s2;

            if( initOutput )
            {
                s0 = _mm_set1_ps(bias0);
                s1 = _mm_set1_ps(bias1);
                s2 = _mm_set1_ps(bias2);
            }
            else
            {
                s0 = _mm_loadu_ps(outptr0 + j);
                s1 = _mm_loadu_ps(outptr1 + j);
                s2 = _mm_loadu_ps(outptr2 + j);
            }

            s0 = _mm_add_ps(s0, _mm256_castps256_ps128(t0));
            s1 = _mm_add_ps(s1, _mm256_castps256_ps128(t1));
            s2 = _mm_add_ps(s2, _mm256_castps256_ps128(t2));

            if( relu )
            {
                __m128 m0 = _mm_cmp_ps(s0, z, _CMP_GT_OS);
                __m128 m1 = _mm_cmp_ps(s1, z, _CMP_GT_OS);
                __m128 m2 = _mm_cmp_ps(s2, z, _CMP_GT_OS);
                s0 = _mm_blendv_ps(_mm_mul_ps(s0, vr0), s0, m0);
                s1 = _mm_blendv_ps(_mm_mul_ps(s1, vr1), s1, m1);
                s2 = _mm_blendv_ps(_mm_mul_ps(s2, vr2), s2, m2);
            }

            if( tail )
            {
                s0 = _mm_blendv_ps(_mm_loadu_ps(outptr0 + j), s0, mask);
                s1 = _mm_blendv_ps(_mm_loadu_ps(outptr1 + j), s1, mask);
                s2 = _mm_blendv_ps(_mm_loadu_ps(outptr2 + j), s2, mask);
            }

            _mm_storeu_ps(outptr0 + j, s0);
            _mm_storeu_ps(outptr1 + j, s1);
            _mm_storeu_ps(outptr2 + j, s2);
        }
        for( ; j <= blockSize - 2; j += 2 )
        {
            const float* rptr0 = rowbuf + j*vecsize_aligned;
            const float* rptr1 = rowbuf + (j+1)*vecsize_aligned;
            float s00, s01, s10, s11, s20, s21;

            if( initOutput )
            {
                s00 = s01 = bias0;
                s10 = s11 = bias1;
                s20 = s21 = bias2;
            }
            else
            {
                s00 = outptr0[j]; s01 = outptr0[j+1];
                s10 = outptr1[j]; s11 = outptr1[j+1];
                s20 = outptr2[j]; s21 = outptr2[j+1];
            }

            for( int k = 0; k < vecsize; k++ )
            {
                float w0 = wptr0[k], w1 = wptr1[k], w2 = wptr2[k];
                float r = rptr0[k];
                s00 += w0*r; s10 += w1*r; s20 += w2*r;
                r = rptr1[k];
                s01 += w0*r; s11 += w1*r; s21 += w2*r;
            }

            if( relu )
            {
                s00 = s00 > 0.f ? s00 : s00*r0;
                s01 = s01 > 0.f ? s01 : s01*r0;
                s10 = s10 > 0.f ? s10 : s10*r1;
                s11 = s11 > 0.f ? s11 : s11*r1;
                s20 = s20 > 0.f ? s20 : s20*r2;
                s21 = s21 > 0.f ? s21 : s21*r2;
            }

            outptr0[j] = s00;
            outptr0[j+1] = s01;
            outptr1[j] = s10;
            outptr1[j+1] = s11;
            outptr2[j] = s20;
            outptr2[j+1] = s21;
        }
        for( ; j < blockSize; j++ )
        {
            const float* rptr0 = rowbuf + j*vecsize_aligned;
            float s00, s10, s20;

            if( initOutput )
            {
                s00 = bias0;
                s10 = bias1;
                s20 = bias2;
            }
            else
            {
                s00 = outptr0[j];
                s10 = outptr1[j];
                s20 = outptr2[j];
            }

            for( int k = 0; k < vecsize; k++ )
            {
                float w0 = wptr0[k], w1 = wptr1[k], w2 = wptr2[k];
                float r = rptr0[k];
                s00 += w0*r; s10 += w1*r; s20 += w2*r;
            }

            if( relu )
            {
                s00 = s00 > 0.f ? s00 : s00*r0;
                s10 = s10 > 0.f ? s10 : s10*r1;
                s20 = s20 > 0.f ? s20 : s20*r2;
            }

            outptr0[j] = s00;
            outptr1[j] = s10;
            outptr2[j] = s20;
        }
    }
    _mm256_zeroupper();
}
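
// Editor's note: _mm256_zeroupper() clears the upper halves of the YMM
// registers before returning, avoiding AVX/SSE transition penalties in any
// legacy-SSE code the caller executes next.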
|
|
|
|
|
|
|
|
|
static inline void _mm256_load_deinterleave(const float* ptr, __m256& a, __m256& b)
{
    __m256 t0 = _mm256_loadu_ps(ptr);
@@ -957,198 +654,6 @@ void fastGEMM1T( const float* vec, const float* weights,
    }
}

enum { FASCONV_BASE_VECSZ = 8 };
void fastConv( const float* weights, size_t wstep, const float* bias,
               const float* rowbuf, float* output, const int* outShape,
               int blockSize, int vecsize, int vecsize_aligned,
               const float* relu, bool initOutput )
{
    const int vlm1 = vsetvlmax_e32m1();
    int outCn = outShape[1];
    size_t outPlaneSize = outShape[2]*outShape[3];
    // now compute dot product of the weights
    // and im2row-transformed part of the tensor
    for( int i = 0; i < outCn; i += 3 )
    {
        int unroll_tail = FASCONV_BASE_VECSZ;
        const float* wptr0 = weights + i*wstep;
        const float* wptr1 = wptr0 + wstep;
        const float* wptr2 = wptr1 + wstep;
        float* outptr0 = output + i*outPlaneSize;
        float* outptr1 = outptr0 + outPlaneSize;
        float* outptr2 = outptr1 + outPlaneSize;
        float bias0 = bias[i], bias1 = bias[i+1], bias2 = bias[i+2];

        if( i+2 >= outCn )
        {
            wptr2 = wptr1;
            outptr2 = outptr1;
            bias2 = bias1;
            if( i+1 >= outCn )
            {
                wptr2 = wptr1 = wptr0;
                outptr2 = outptr1 = outptr0;
                bias2 = bias1 = bias0;
            }
        }

        int j = 0;
        for( ; j < blockSize; j += FASCONV_BASE_VECSZ )
        {
            const float* rptr = rowbuf + j*vecsize_aligned;
            const float *rptr1 = rptr + vecsize_aligned*1,
                        *rptr2 = rptr + vecsize_aligned*2,
                        *rptr3 = rptr + vecsize_aligned*3,
                        *rptr4 = rptr + vecsize_aligned*4,
                        *rptr5 = rptr + vecsize_aligned*5,
                        *rptr6 = rptr + vecsize_aligned*6,
                        *rptr7 = rptr + vecsize_aligned*7;
            if (j + FASCONV_BASE_VECSZ > blockSize)
            {
                unroll_tail = blockSize - j;
                rptr1 = rptr + vecsize_aligned*std::min(1, unroll_tail-1),
                rptr2 = rptr + vecsize_aligned*std::min(2, unroll_tail-1),
                rptr3 = rptr + vecsize_aligned*std::min(3, unroll_tail-1),
                rptr4 = rptr + vecsize_aligned*std::min(4, unroll_tail-1),
                rptr5 = rptr + vecsize_aligned*std::min(5, unroll_tail-1),
                rptr6 = rptr + vecsize_aligned*std::min(6, unroll_tail-1),
                rptr7 = rptr + vecsize_aligned*std::min(7, unroll_tail-1);
            }

            int vl, avl = vecsize;
            vfloat32m1_t
                vs00 = vfmv_v_f_f32m1(0, vlm1), vs10 = vfmv_v_f_f32m1(0, vlm1), vs20 = vfmv_v_f_f32m1(0, vlm1),
                vs01 = vfmv_v_f_f32m1(0, vlm1), vs11 = vfmv_v_f_f32m1(0, vlm1), vs21 = vfmv_v_f_f32m1(0, vlm1),
                vs02 = vfmv_v_f_f32m1(0, vlm1), vs12 = vfmv_v_f_f32m1(0, vlm1), vs22 = vfmv_v_f_f32m1(0, vlm1),
                vs03 = vfmv_v_f_f32m1(0, vlm1), vs13 = vfmv_v_f_f32m1(0, vlm1), vs23 = vfmv_v_f_f32m1(0, vlm1),
                vs04 = vfmv_v_f_f32m1(0, vlm1), vs14 = vfmv_v_f_f32m1(0, vlm1), vs24 = vfmv_v_f_f32m1(0, vlm1),
                vs05 = vfmv_v_f_f32m1(0, vlm1), vs15 = vfmv_v_f_f32m1(0, vlm1), vs25 = vfmv_v_f_f32m1(0, vlm1),
                vs06 = vfmv_v_f_f32m1(0, vlm1), vs16 = vfmv_v_f_f32m1(0, vlm1), vs26 = vfmv_v_f_f32m1(0, vlm1),
                vs07 = vfmv_v_f_f32m1(0, vlm1), vs17 = vfmv_v_f_f32m1(0, vlm1), vs27 = vfmv_v_f_f32m1(0, vlm1);
            for (int k = 0; k < vecsize; k += vl, avl -= vl)
            {
                vl = vsetvl_e32m1(avl);
                vfloat32m1_t w0 = vle32_v_f32m1(wptr0 + k, vl);
                vfloat32m1_t w1 = vle32_v_f32m1(wptr1 + k, vl);
                vfloat32m1_t w2 = vle32_v_f32m1(wptr2 + k, vl);
                vfloat32m1_t r0 = vle32_v_f32m1(rptr, vl);

                vs00 = vfmacc_vv_f32m1(vs00, w0, r0, vl);
                vs10 = vfmacc_vv_f32m1(vs10, w1, r0, vl);
                vs20 = vfmacc_vv_f32m1(vs20, w2, r0, vl);

                r0 = vle32_v_f32m1(rptr1, vl);
                vs01 = vfmacc_vv_f32m1(vs01, w0, r0, vl);
                vs11 = vfmacc_vv_f32m1(vs11, w1, r0, vl);
                vs21 = vfmacc_vv_f32m1(vs21, w2, r0, vl);

                r0 = vle32_v_f32m1(rptr2, vl);
                vs02 = vfmacc_vv_f32m1(vs02, w0, r0, vl);
                vs12 = vfmacc_vv_f32m1(vs12, w1, r0, vl);
                vs22 = vfmacc_vv_f32m1(vs22, w2, r0, vl);

                r0 = vle32_v_f32m1(rptr3, vl);
                vs03 = vfmacc_vv_f32m1(vs03, w0, r0, vl);
                vs13 = vfmacc_vv_f32m1(vs13, w1, r0, vl);
                vs23 = vfmacc_vv_f32m1(vs23, w2, r0, vl);

                r0 = vle32_v_f32m1(rptr4, vl);
                vs04 = vfmacc_vv_f32m1(vs04, w0, r0, vl);
                vs14 = vfmacc_vv_f32m1(vs14, w1, r0, vl);
                vs24 = vfmacc_vv_f32m1(vs24, w2, r0, vl);

                r0 = vle32_v_f32m1(rptr5, vl);
                vs05 = vfmacc_vv_f32m1(vs05, w0, r0, vl);
                vs15 = vfmacc_vv_f32m1(vs15, w1, r0, vl);
                vs25 = vfmacc_vv_f32m1(vs25, w2, r0, vl);

                r0 = vle32_v_f32m1(rptr6, vl);
                vs06 = vfmacc_vv_f32m1(vs06, w0, r0, vl);
                vs16 = vfmacc_vv_f32m1(vs16, w1, r0, vl);
                vs26 = vfmacc_vv_f32m1(vs26, w2, r0, vl);

                r0 = vle32_v_f32m1(rptr7, vl);
                vs07 = vfmacc_vv_f32m1(vs07, w0, r0, vl);
                vs17 = vfmacc_vv_f32m1(vs17, w1, r0, vl);
                vs27 = vfmacc_vv_f32m1(vs27, w2, r0, vl);

                rptr += vl; rptr1 += vl; rptr2 += vl; rptr3 += vl;
                rptr4 += vl; rptr5 += vl; rptr6 += vl; rptr7 += vl;
            }
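
            // Editor's note: the loop above is the standard RVV strip-mining
            // pattern -- vsetvl_e32m1(avl) returns how many 32-bit lanes (vl)
            // are processed this pass, avl counts the elements still pending,
            // and the row pointers advance by vl each iteration until all
            // `vecsize` weights have been accumulated.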
|
|
|
|
|
|
|
|
|
            // compute sum of each vs
            vfloat32m1_t zero = vfmv_v_f_f32m1(0, vlm1);
            // unroll_tail(vl) is required here to be at least FASCONV_BASE_VECSZ, aka 8.
            float sum0[FASCONV_BASE_VECSZ], sum1[FASCONV_BASE_VECSZ], sum2[FASCONV_BASE_VECSZ];
            sum0[0] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs00, zero, vlm1));
            sum0[1] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs01, zero, vlm1));
            sum0[2] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs02, zero, vlm1));
            sum0[3] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs03, zero, vlm1));
            sum0[4] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs04, zero, vlm1));
            sum0[5] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs05, zero, vlm1));
            sum0[6] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs06, zero, vlm1));
            sum0[7] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs07, zero, vlm1));
            sum1[0] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs10, zero, vlm1));
            sum1[1] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs11, zero, vlm1));
            sum1[2] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs12, zero, vlm1));
            sum1[3] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs13, zero, vlm1));
            sum1[4] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs14, zero, vlm1));
            sum1[5] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs15, zero, vlm1));
            sum1[6] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs16, zero, vlm1));
            sum1[7] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs17, zero, vlm1));
            sum2[0] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs20, zero, vlm1));
            sum2[1] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs21, zero, vlm1));
            sum2[2] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs22, zero, vlm1));
            sum2[3] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs23, zero, vlm1));
            sum2[4] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs24, zero, vlm1));
            sum2[5] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs25, zero, vlm1));
            sum2[6] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs26, zero, vlm1));
            sum2[7] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs27, zero, vlm1));
            // If VLEN == 128, LMUL = 2 is needed to hold unroll_tail (vl) = 8 floats;
            // otherwise (VLEN >= 256) only the first 8 elements of the vector register are used.
            vfloat32m2_t s0, s1, s2;
            if( initOutput )
            {
                s0 = vfmv_v_f_f32m2(bias0, unroll_tail);
                s1 = vfmv_v_f_f32m2(bias1, unroll_tail);
                s2 = vfmv_v_f_f32m2(bias2, unroll_tail);
            }
            else
            {
                s0 = vle32_v_f32m2(outptr0 + j, unroll_tail);
                s1 = vle32_v_f32m2(outptr1 + j, unroll_tail);
                s2 = vle32_v_f32m2(outptr2 + j, unroll_tail);
            }
            s0 = vfadd_vv_f32m2(vle32_v_f32m2(sum0, unroll_tail), s0, unroll_tail);
            s1 = vfadd_vv_f32m2(vle32_v_f32m2(sum1, unroll_tail), s1, unroll_tail);
            s2 = vfadd_vv_f32m2(vle32_v_f32m2(sum2, unroll_tail), s2, unroll_tail);

            if( relu )
            {
                float r0 = relu[i], r1 = relu[i+1], r2 = relu[i+2];
                if( i+2 >= outCn )
                {
                    r2 = r1;
                    if( i+1 >= outCn )
                        r2 = r1 = r0;
                }
                vbool16_t m0 = vmfgt_vf_f32m2_b16(s0, 0, unroll_tail);
                vbool16_t m1 = vmfgt_vf_f32m2_b16(s1, 0, unroll_tail);
                vbool16_t m2 = vmfgt_vf_f32m2_b16(s2, 0, unroll_tail);
                s0 = vmerge_vvm_f32m2(m0, vfmul_vf_f32m2(s0, r0, unroll_tail), s0, unroll_tail);
                s1 = vmerge_vvm_f32m2(m1, vfmul_vf_f32m2(s1, r1, unroll_tail), s1, unroll_tail);
                s2 = vmerge_vvm_f32m2(m2, vfmul_vf_f32m2(s2, r2, unroll_tail), s2, unroll_tail);
            }

            vse32_v_f32m2(outptr0 + j, s0, unroll_tail);
            vse32_v_f32m2(outptr1 + j, s1, unroll_tail);
            vse32_v_f32m2(outptr2 + j, s2, unroll_tail);
        }
    }
}
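
// Editor's note: unlike the AVX branch above, this RVV variant has no blend
// mask for a partial last block -- it shrinks unroll_tail to blockSize - j and
// clamps the extra row pointers to the last valid row, so the redundant dot
// products are computed but never stored.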
|
|
|
|
|
|
|
|
|
/*
Example for load_deinterleave:
    input: ptr[16] = {1,2,3, ... ,14,15,16}
@@ -1345,317 +850,6 @@ void fastDepthwiseConv( const float* wptr,

#if !defined(CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY) && CV_LASX

enum { FASCONV_BASE_VECSZ = 4 };

void fastConv( const float* weights, size_t wstep, const float* bias,
               const float* rowbuf, float* output, const int* outShape,
               int blockSize, int vecsize, int vecsize_aligned,
               const float* relu, bool initOutput )
{
    int outCn = outShape[1];
    size_t outPlaneSize = outShape[2]*outShape[3];
    float r0 = 1.f, r1 = 1.f, r2 = 1.f;
    __m256 t1 = _v256_setall_ps(1.f), t2 = _v256_setall_ps(0.f);
    __m128 vr0 = *(__m128*)&t1, vr1 = vr0, vr2 = vr0, z = *(__m128*)&t2;
    int CV_DECL_ALIGNED(16) maskbuf[FASCONV_BASE_VECSZ] = {0};
    int rsz = blockSize % FASCONV_BASE_VECSZ;
    for( int i = 0; i < rsz; i++ )
        maskbuf[FASCONV_BASE_VECSZ - i - 1] = -1;
    __m128i mask = __lsx_vld((const float*)maskbuf, 0);
    // now compute dot product of the weights
    // and im2row-transformed part of the tensor
    for( int i = 0; i < outCn; i += 3 )
    {
        const float* wptr0 = weights + i*wstep;
        const float* wptr1 = wptr0 + wstep;
        const float* wptr2 = wptr1 + wstep;
        float* outptr0 = output + i*outPlaneSize;
        float* outptr1 = outptr0 + outPlaneSize;
        float* outptr2 = outptr1 + outPlaneSize;
        float bias0 = bias[i], bias1 = bias[i+1], bias2 = bias[i+2];

        if( i+2 >= outCn )
        {
            wptr2 = wptr1;
            outptr2 = outptr1;
            bias2 = bias1;
            if( i+1 >= outCn )
            {
                wptr2 = wptr1 = wptr0;
                outptr2 = outptr1 = outptr0;
                bias2 = bias1 = bias0;
            }
        }

        if( relu )
        {
            r0 = relu[i]; r1 = relu[i+1]; r2 = relu[i+2];
            if( i+2 >= outCn )
            {
                r2 = r1;
                if( i+1 >= outCn )
                    r2 = r1 = r0;
            }
            vr0 = _v256_extract_low(_v256_setall_ps(r0));
            vr1 = _v256_extract_low(_v256_setall_ps(r1));
            vr2 = _v256_extract_low(_v256_setall_ps(r2));
        }

        int j = 0;
        for( ; j < blockSize; j += FASCONV_BASE_VECSZ )
        {
            bool tail = false;
            if (j + FASCONV_BASE_VECSZ > blockSize)
            {
                if (j == 0)
                    break;
                j = blockSize - FASCONV_BASE_VECSZ;
                tail = true;
            }
            int k = 0;
            const float* rptr = rowbuf + j*vecsize_aligned;

            __m256i tmp;
            __m256 vs00 = (__m256)__lasx_xvxor_v(tmp, tmp), vs01 = (__m256)__lasx_xvxor_v(tmp, tmp),
                   vs02 = (__m256)__lasx_xvxor_v(tmp, tmp), vs03 = (__m256)__lasx_xvxor_v(tmp, tmp),
                   vs10 = (__m256)__lasx_xvxor_v(tmp, tmp), vs11 = (__m256)__lasx_xvxor_v(tmp, tmp),
                   vs12 = (__m256)__lasx_xvxor_v(tmp, tmp), vs13 = (__m256)__lasx_xvxor_v(tmp, tmp),
                   vs20 = (__m256)__lasx_xvxor_v(tmp, tmp), vs21 = (__m256)__lasx_xvxor_v(tmp, tmp),
                   vs22 = (__m256)__lasx_xvxor_v(tmp, tmp), vs23 = (__m256)__lasx_xvxor_v(tmp, tmp);
            for (; k < vecsize; k += 8, rptr += 8 )
            {
                __m256 w0 = (__m256)__lasx_xvld(wptr0 + k, 0);
                __m256 w1 = (__m256)__lasx_xvld(wptr1 + k, 0);
                __m256 w2 = (__m256)__lasx_xvld(wptr2 + k, 0);
                __m256 r0 = (__m256)__lasx_xvld(rptr, 0);

                vs00 = __lasx_xvfmadd_s(w0, r0, vs00);
                vs10 = __lasx_xvfmadd_s(w1, r0, vs10);
                vs20 = __lasx_xvfmadd_s(w2, r0, vs20);

                r0 = (__m256)__lasx_xvld(rptr + vecsize_aligned, 0);
                vs01 = __lasx_xvfmadd_s(w0, r0, vs01);
                vs11 = __lasx_xvfmadd_s(w1, r0, vs11);
                vs21 = __lasx_xvfmadd_s(w2, r0, vs21);

                r0 = (__m256)__lasx_xvld(rptr + vecsize_aligned*2, 0);
                vs02 = __lasx_xvfmadd_s(w0, r0, vs02);
                vs12 = __lasx_xvfmadd_s(w1, r0, vs12);
                vs22 = __lasx_xvfmadd_s(w2, r0, vs22);

                r0 = (__m256)__lasx_xvld(rptr + vecsize_aligned*3, 0);
                vs03 = __lasx_xvfmadd_s(w0, r0, vs03);
                vs13 = __lasx_xvfmadd_s(w1, r0, vs13);
                vs23 = __lasx_xvfmadd_s(w2, r0, vs23);
            }
            /*t0*/
            __m256 vs00_perm = (__m256)__lasx_xvpermi_d(vs00, (2<<6) + (3<<4) + (0<<2) + 1);
            __m256 vs00_add_2w = __lasx_xvfadd_s(vs00, vs00_perm);
            __m256 tmp00_srl = (__m256)__lasx_xvsrli_d(vs00_add_2w, 32);
            __m256 vs00_add_4w = __lasx_xvfadd_s(vs00_add_2w, tmp00_srl);

            __m256 vs01_perm = (__m256)__lasx_xvpermi_d(vs01, (2<<6) + (3<<4) + (0<<2) + 1);
            __m256 vs01_add_2w = __lasx_xvfadd_s(vs01, vs01_perm);
            __m256 tmp01_srl = (__m256)__lasx_xvsrli_d(vs01_add_2w, 32);
            __m256 vs01_add_4w = __lasx_xvfadd_s(vs01_add_2w, tmp01_srl);

            __m256 vs02_perm = (__m256)__lasx_xvpermi_d(vs02, (2<<6) + (3<<4) + (0<<2) + 1);
            __m256 vs02_add_2w = __lasx_xvfadd_s(vs02, vs02_perm);
            __m256 tmp02_srl = (__m256)__lasx_xvsrli_d(vs02_add_2w, 32);
            __m256 vs02_add_4w = __lasx_xvfadd_s(vs02_add_2w, tmp02_srl);

            __m256 vs03_perm = (__m256)__lasx_xvpermi_d(vs03, (2<<6) + (3<<4) + (0<<2) + 1);
            __m256 vs03_add_2w = __lasx_xvfadd_s(vs03, vs03_perm);
            __m256 tmp03_srl = (__m256)__lasx_xvsrli_d(vs03_add_2w, 32);
            __m256 vs03_add_4w = __lasx_xvfadd_s(vs03_add_2w, tmp03_srl);

            __m256i vs01_vs00 = __lasx_xvpackev_w((__m256i)vs01_add_4w, (__m256i)vs00_add_4w);
            __m256i vs03_vs02 = __lasx_xvpackev_w((__m256i)vs03_add_4w, (__m256i)vs02_add_4w);
            __m256 t0 = (__m256)__lasx_xvpackev_d(vs03_vs02, vs01_vs00);
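
            // Editor's note: each perm/add/srl/add quartet above reduces one
            // accumulator to per-128-bit-lane partial sums, and the two
            // xvpackev steps gather the four per-pixel sums into t0; the
            // cross-lane xvfadd further below then completes the horizontal
            // sums, mirroring the hadd/permute2f128 pattern of the AVX2
            // branch. The same sequence follows for t1 and t2.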
|
|
|
|
|
|
|
|
|
            /*t1*/
            __m256 vs10_perm = (__m256)__lasx_xvpermi_d(vs10, (2<<6) + (3<<4) + (0<<2) + 1);
            __m256 vs10_add_2w = __lasx_xvfadd_s(vs10, vs10_perm);
            __m256 tmp10_srl = (__m256)__lasx_xvsrli_d(vs10_add_2w, 32);
            __m256 vs10_add_4w = __lasx_xvfadd_s(vs10_add_2w, tmp10_srl);

            __m256 vs11_perm = (__m256)__lasx_xvpermi_d(vs11, (2<<6) + (3<<4) + (0<<2) + 1);
            __m256 vs11_add_2w = __lasx_xvfadd_s(vs11, vs11_perm);
            __m256 tmp11_srl = (__m256)__lasx_xvsrli_d(vs11_add_2w, 32);
            __m256 vs11_add_4w = __lasx_xvfadd_s(vs11_add_2w, tmp11_srl);

            __m256 vs12_perm = (__m256)__lasx_xvpermi_d(vs12, (2<<6) + (3<<4) + (0<<2) + 1);
            __m256 vs12_add_2w = __lasx_xvfadd_s(vs12, vs12_perm);
            __m256 tmp12_srl = (__m256)__lasx_xvsrli_d(vs12_add_2w, 32);
            __m256 vs12_add_4w = __lasx_xvfadd_s(vs12_add_2w, tmp12_srl);

            __m256 vs13_perm = (__m256)__lasx_xvpermi_d(vs13, (2<<6) + (3<<4) + (0<<2) + 1);
            __m256 vs13_add_2w = __lasx_xvfadd_s(vs13, vs13_perm);
            __m256 tmp13_srl = (__m256)__lasx_xvsrli_d(vs13_add_2w, 32);
            __m256 vs13_add_4w = __lasx_xvfadd_s(vs13_add_2w, tmp13_srl);

            __m256i vs11_vs10 = __lasx_xvpackev_w((__m256i)vs11_add_4w, (__m256i)vs10_add_4w);
            __m256i vs13_vs12 = __lasx_xvpackev_w((__m256i)vs13_add_4w, (__m256i)vs12_add_4w);
            __m256 t1 = (__m256)__lasx_xvpackev_d(vs13_vs12, vs11_vs10);

            /*t2*/
            __m256 vs20_perm = (__m256)__lasx_xvpermi_d(vs20, (2<<6) + (3<<4) + (0<<2) + 1);
            __m256 vs20_add_2w = __lasx_xvfadd_s(vs20, vs20_perm);
            __m256 tmp20_srl = (__m256)__lasx_xvsrli_d(vs20_add_2w, 32);
            __m256 vs20_add_4w = __lasx_xvfadd_s(vs20_add_2w, tmp20_srl);

            __m256 vs21_perm = (__m256)__lasx_xvpermi_d(vs21, (2<<6) + (3<<4) + (0<<2) + 1);
            __m256 vs21_add_2w = __lasx_xvfadd_s(vs21, vs21_perm);
            __m256 tmp21_srl = (__m256)__lasx_xvsrli_d(vs21_add_2w, 32);
            __m256 vs21_add_4w = __lasx_xvfadd_s(vs21_add_2w, tmp21_srl);

            __m256 vs22_perm = (__m256)__lasx_xvpermi_d(vs22, (2<<6) + (3<<4) + (0<<2) + 1);
            __m256 vs22_add_2w = __lasx_xvfadd_s(vs22, vs22_perm);
            __m256 tmp22_srl = (__m256)__lasx_xvsrli_d(vs22_add_2w, 32);
            __m256 vs22_add_4w = __lasx_xvfadd_s(vs22_add_2w, tmp22_srl);

            __m256 vs23_perm = (__m256)__lasx_xvpermi_d(vs23, (2<<6) + (3<<4) + (0<<2) + 1);
            __m256 vs23_add_2w = __lasx_xvfadd_s(vs23, vs23_perm);
            __m256 tmp23_srl = (__m256)__lasx_xvsrli_d(vs23_add_2w, 32);
            __m256 vs23_add_4w = __lasx_xvfadd_s(vs23_add_2w, tmp23_srl);

            __m256i vs21_vs20 = __lasx_xvpackev_w((__m256i)vs21_add_4w, (__m256i)vs20_add_4w);
            __m256i vs23_vs22 = __lasx_xvpackev_w((__m256i)vs23_add_4w, (__m256i)vs22_add_4w);
            __m256 t2 = (__m256)__lasx_xvpackev_d(vs23_vs22, vs21_vs20);

            t0 = __lasx_xvfadd_s(t0, (__m256)__lasx_xvpermi_q(t0, t0, 1));
            t1 = __lasx_xvfadd_s(t1, (__m256)__lasx_xvpermi_q(t1, t1, 1));
            t2 = __lasx_xvfadd_s(t2, (__m256)__lasx_xvpermi_q(t2, t2, 1));
            __m128 s0, s1, s2;

            if( initOutput )
            {
                s0 = _v256_extract_low(_v256_setall_ps(bias0));
                s1 = _v256_extract_low(_v256_setall_ps(bias1));
                s2 = _v256_extract_low(_v256_setall_ps(bias2));
            }
            else
            {
                s0 = (__m128)__lsx_vld(outptr0 + j, 0);
                s1 = (__m128)__lsx_vld(outptr1 + j, 0);
                s2 = (__m128)__lsx_vld(outptr2 + j, 0);
            }

            s0 = __lsx_vfadd_s(s0, *(__m128*)&t0);
            s1 = __lsx_vfadd_s(s1, *(__m128*)&t1);
            s2 = __lsx_vfadd_s(s2, *(__m128*)&t2);

            if( relu )
            {
                __m128i m0 = __lsx_vfcmp_clt_s(z, s0);
                __m128i m1 = __lsx_vfcmp_clt_s(z, s1);
                __m128i m2 = __lsx_vfcmp_clt_s(z, s2);
                s0 = (__m128)__lsx_vbitsel_v((__m128i)__lsx_vfmul_s(s0, vr0), (__m128i)s0, m0);
                s1 = (__m128)__lsx_vbitsel_v((__m128i)__lsx_vfmul_s(s1, vr1), (__m128i)s1, m1);
                s2 = (__m128)__lsx_vbitsel_v((__m128i)__lsx_vfmul_s(s2, vr2), (__m128i)s2, m2);
            }

            if( tail )
            {
                s0 = (__m128)__lsx_vbitsel_v(__lsx_vld(outptr0 + j, 0), (__m128i)s0, mask);
                s1 = (__m128)__lsx_vbitsel_v(__lsx_vld(outptr1 + j, 0), (__m128i)s1, mask);
                s2 = (__m128)__lsx_vbitsel_v(__lsx_vld(outptr2 + j, 0), (__m128i)s2, mask);
            }

            __lsx_vst(s0, outptr0 + j, 0);
            __lsx_vst(s1, outptr1 + j, 0);
            __lsx_vst(s2, outptr2 + j, 0);
        }
        for( ; j <= blockSize - 2; j += 2 )
        {
            const float* rptr0 = rowbuf + j*vecsize_aligned;
            const float* rptr1 = rowbuf + (j+1)*vecsize_aligned;
            float s00, s01, s10, s11, s20, s21;

            if( initOutput )
            {
                s00 = s01 = bias0;
                s10 = s11 = bias1;
                s20 = s21 = bias2;
            }
            else
            {
                s00 = outptr0[j]; s01 = outptr0[j+1];
                s10 = outptr1[j]; s11 = outptr1[j+1];
                s20 = outptr2[j]; s21 = outptr2[j+1];
            }

            for( int k = 0; k < vecsize; k++ )
            {
                float w0 = wptr0[k], w1 = wptr1[k], w2 = wptr2[k];
                float r = rptr0[k];
                s00 += w0*r; s10 += w1*r; s20 += w2*r;
                r = rptr1[k];
                s01 += w0*r; s11 += w1*r; s21 += w2*r;
            }

            if( relu )
            {
                s00 = s00 > 0.f ? s00 : s00*r0;
                s01 = s01 > 0.f ? s01 : s01*r0;
                s10 = s10 > 0.f ? s10 : s10*r1;
                s11 = s11 > 0.f ? s11 : s11*r1;
                s20 = s20 > 0.f ? s20 : s20*r2;
                s21 = s21 > 0.f ? s21 : s21*r2;
            }

            outptr0[j] = s00;
            outptr0[j+1] = s01;
            outptr1[j] = s10;
            outptr1[j+1] = s11;
            outptr2[j] = s20;
            outptr2[j+1] = s21;
        }
        for( ; j < blockSize; j++ )
        {
            const float* rptr0 = rowbuf + j*vecsize_aligned;
            float s00, s10, s20;

            if( initOutput )
            {
                s00 = bias0;
                s10 = bias1;
                s20 = bias2;
            }
            else
            {
                s00 = outptr0[j];
                s10 = outptr1[j];
                s20 = outptr2[j];
            }

            for( int k = 0; k < vecsize; k++ )
            {
                float w0 = wptr0[k], w1 = wptr1[k], w2 = wptr2[k];
                float r = rptr0[k];
                s00 += w0*r; s10 += w1*r; s20 += w2*r;
            }

            if( relu )
            {
                s00 = s00 > 0.f ? s00 : s00*r0;
                s10 = s10 > 0.f ? s10 : s10*r1;
                s20 = s20 > 0.f ? s20 : s20*r2;
            }

            outptr0[j] = s00;
            outptr1[j] = s10;
            outptr2[j] = s20;
        }
    }
}
static inline void _v256_load_deinterleave(const float* ptr, __m256& a, __m256& b)
{
    __m256 t0 = (__m256)__lasx_xvld(ptr, 0);