Merge pull request #21086 from hanliutong:rvv-dnn

Further optimize DNN for RISC-V Vector.

* Optimize DNN on RVV by using vsetvl.

* Rename vl.

* Update fastConv by using vsetvl instead of mask.

* Fix fastDepthwiseConv
Authored by HAN Liutong, committed by GitHub
parent e3e04f5dae
commit 1599f9f0c0
1 changed file: modules/dnn/src/layers/layers_common.simd.hpp (412 changed lines)
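
The patch replaces the fixed-VLMAX loops (with their explicit mvl/mask tail handling) by the standard RVV strip-mining idiom: each iteration requests the vector length for the elements that remain with vsetvl, so the main loop body also covers the tail. A minimal sketch of the idiom with the same v0.10 intrinsics the patch uses; the axpy-style kernel and its names below are illustrative only, not code from this PR:

    #include <riscv_vector.h>

    // y[i] += a * x[i] for n floats, strip-mined with vsetvl (illustrative sketch).
    void axpy_rvv(float a, const float* x, float* y, int n)
    {
        int avl = n, vl;                          // avl: elements still to be processed
        for (int i = 0; i < n; i += vl, avl -= vl)
        {
            vl = vsetvl_e32m2(avl);               // hardware picks vl <= min(avl, VLMAX)
            vfloat32m2_t vx = vle32_v_f32m2(x + i, vl);
            vfloat32m2_t vy = vle32_v_f32m2(y + i, vl);
            vy = vfmacc_vf_f32m2(vy, a, vx, vl);  // vy += a * vx
            vse32_v_f32m2(y + i, vy, vl);
        }
    }

On the last iteration avl is smaller than VLMAX, vsetvl returns exactly the number of remaining elements, and the loads and stores never touch memory past the buffers; this is what lets fastGEMM, fastGEMM1T, fastConv and fastDepthwiseConv drop their separate mvl/mask tail paths below.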

@@ -782,15 +782,10 @@ void fastGEMM( const float* aptr, size_t astep, const float* bptr,
                size_t bstep, float* cptr, size_t cstep,
                int ma, int na, int nb )
 {
-    int n = 0;
-    int vl = vsetvlmax_e32m4();
-    int mvl = vl;
-    for( ; n < nb; n += vl )
+    int avl = nb, vl;
+    for(int n = 0; n < nb; n += vl, avl -= vl)
     {
-        if ( n + vl > nb) {
-            mvl = nb - n;
-        }
+        vl = vsetvl_e32m4(avl);
         for( int m = 0; m < ma; m += 7 )
         {
             const float* aptr0 = aptr + astep*m;
@@ -827,22 +822,22 @@ void fastGEMM( const float* aptr, size_t astep, const float* bptr,
                 float a5 = aptr5[k];
                 float a6 = aptr6[k];
-                vfloat32m4_t b = vle32_v_f32m4(bptr + k*bstep + n, mvl);
-                d0 = vfmacc_vf_f32m4(d0, a0, b, mvl);
-                d1 = vfmacc_vf_f32m4(d1, a1, b, mvl);
-                d2 = vfmacc_vf_f32m4(d2, a2, b, mvl);
-                d3 = vfmacc_vf_f32m4(d3, a3, b, mvl);
-                d4 = vfmacc_vf_f32m4(d4, a4, b, mvl);
-                d5 = vfmacc_vf_f32m4(d5, a5, b, mvl);
-                d6 = vfmacc_vf_f32m4(d6, a6, b, mvl);
+                vfloat32m4_t b = vle32_v_f32m4(bptr + k*bstep + n, vl);
+                d0 = vfmacc_vf_f32m4(d0, a0, b, vl);
+                d1 = vfmacc_vf_f32m4(d1, a1, b, vl);
+                d2 = vfmacc_vf_f32m4(d2, a2, b, vl);
+                d3 = vfmacc_vf_f32m4(d3, a3, b, vl);
+                d4 = vfmacc_vf_f32m4(d4, a4, b, vl);
+                d5 = vfmacc_vf_f32m4(d5, a5, b, vl);
+                d6 = vfmacc_vf_f32m4(d6, a6, b, vl);
             }
-            vse32_v_f32m4(cptr0 + n, d0, mvl);
-            vse32_v_f32m4(cptr1 + n, d1, mvl);
-            vse32_v_f32m4(cptr2 + n, d2, mvl);
-            vse32_v_f32m4(cptr3 + n, d3, mvl);
-            vse32_v_f32m4(cptr4 + n, d4, mvl);
-            vse32_v_f32m4(cptr5 + n, d5, mvl);
-            vse32_v_f32m4(cptr6 + n, d6, mvl);
+            vse32_v_f32m4(cptr0 + n, d0, vl);
+            vse32_v_f32m4(cptr1 + n, d1, vl);
+            vse32_v_f32m4(cptr2 + n, d2, vl);
+            vse32_v_f32m4(cptr3 + n, d3, vl);
+            vse32_v_f32m4(cptr4 + n, d4, vl);
+            vse32_v_f32m4(cptr5 + n, d5, vl);
+            vse32_v_f32m4(cptr6 + n, d6, vl);
         }
     }
 }
@@ -851,7 +846,7 @@ void fastGEMM1T( const float* vec, const float* weights,
                  size_t wstep, const float* bias,
                  float* dst, int nvecs, int vecsize )
 {
-    int vlm2 = vsetvlmax_e32m2();
+    const int vlm2 = vsetvlmax_e32m2();
     int i = 0;
     for( ; i <= nvecs - 15; i += 15 )
     {
@@ -862,45 +857,26 @@ void fastGEMM1T( const float* vec, const float* weights,
             vs6 = vfmv_v_f_f32m2(0, vlm2), vs7 = vfmv_v_f_f32m2(0, vlm2), vs8 = vfmv_v_f_f32m2(0, vlm2),
             vs9 = vfmv_v_f_f32m2(0, vlm2), vs10 = vfmv_v_f_f32m2(0, vlm2), vs11 = vfmv_v_f_f32m2(0, vlm2),
             vs12 = vfmv_v_f_f32m2(0, vlm2), vs13 = vfmv_v_f_f32m2(0, vlm2), vs14 = vfmv_v_f_f32m2(0, vlm2);
-        int k = 0;
-        for( ; k < vecsize - vlm2; k += vlm2, wptr += vlm2 )
+        int avl = vecsize, vl;
+        for(int k = 0 ; k < vecsize; k += vl, wptr += vl, avl -= vl)
         {
-            vfloat32m2_t v = vle32_v_f32m2(vec + k, vlm2);
-            vs0 = vfmacc_vv_f32m2(vs0, vle32_v_f32m2(wptr, vlm2), v, vlm2);
-            vs1 = vfmacc_vv_f32m2(vs1, vle32_v_f32m2(wptr + wstep, vlm2), v, vlm2);
-            vs2 = vfmacc_vv_f32m2(vs2, vle32_v_f32m2(wptr + wstep*2, vlm2), v, vlm2);
-            vs3 = vfmacc_vv_f32m2(vs3, vle32_v_f32m2(wptr + wstep*3, vlm2), v, vlm2);
-            vs4 = vfmacc_vv_f32m2(vs4, vle32_v_f32m2(wptr + wstep*4, vlm2), v, vlm2);
-            vs5 = vfmacc_vv_f32m2(vs5, vle32_v_f32m2(wptr + wstep*5, vlm2), v, vlm2);
-            vs6 = vfmacc_vv_f32m2(vs6, vle32_v_f32m2(wptr + wstep*6, vlm2), v, vlm2);
-            vs7 = vfmacc_vv_f32m2(vs7, vle32_v_f32m2(wptr + wstep*7, vlm2), v, vlm2);
-            vs8 = vfmacc_vv_f32m2(vs8, vle32_v_f32m2(wptr + wstep*8, vlm2), v, vlm2);
-            vs9 = vfmacc_vv_f32m2(vs9, vle32_v_f32m2(wptr + wstep*9, vlm2), v, vlm2);
-            vs10 = vfmacc_vv_f32m2(vs10, vle32_v_f32m2(wptr + wstep*10, vlm2), v, vlm2);
-            vs11 = vfmacc_vv_f32m2(vs11, vle32_v_f32m2(wptr + wstep*11, vlm2), v, vlm2);
-            vs12 = vfmacc_vv_f32m2(vs12, vle32_v_f32m2(wptr + wstep*12, vlm2), v, vlm2);
-            vs13 = vfmacc_vv_f32m2(vs13, vle32_v_f32m2(wptr + wstep*13, vlm2), v, vlm2);
-            vs14 = vfmacc_vv_f32m2(vs14, vle32_v_f32m2(wptr + wstep*14, vlm2), v, vlm2);
-        }
-        int kvl = vecsize - k;
-        if (kvl > 0) {
-            vfloat32m2_t v = vle32_v_f32m2(vec + k, kvl);
-            vs0 = vfmacc_vv_f32m2(vs0, vle32_v_f32m2(wptr, kvl), v, kvl);
-            vs1 = vfmacc_vv_f32m2(vs1, vle32_v_f32m2(wptr + wstep*1, kvl), v, kvl);
-            vs2 = vfmacc_vv_f32m2(vs2, vle32_v_f32m2(wptr + wstep*2, kvl), v, kvl);
-            vs3 = vfmacc_vv_f32m2(vs3, vle32_v_f32m2(wptr + wstep*3, kvl), v, kvl);
-            vs4 = vfmacc_vv_f32m2(vs4, vle32_v_f32m2(wptr + wstep*4, kvl), v, kvl);
-            vs5 = vfmacc_vv_f32m2(vs5, vle32_v_f32m2(wptr + wstep*5, kvl), v, kvl);
-            vs6 = vfmacc_vv_f32m2(vs6, vle32_v_f32m2(wptr + wstep*6, kvl), v, kvl);
-            vs7 = vfmacc_vv_f32m2(vs7, vle32_v_f32m2(wptr + wstep*7, kvl), v, kvl);
-            vs8 = vfmacc_vv_f32m2(vs8, vle32_v_f32m2(wptr + wstep*8, kvl), v, kvl);
-            vs9 = vfmacc_vv_f32m2(vs9, vle32_v_f32m2(wptr + wstep*9, kvl), v, kvl);
-            vs10 = vfmacc_vv_f32m2(vs10, vle32_v_f32m2(wptr + wstep*10, kvl), v, kvl);
-            vs11 = vfmacc_vv_f32m2(vs11, vle32_v_f32m2(wptr + wstep*11, kvl), v, kvl);
-            vs12 = vfmacc_vv_f32m2(vs12, vle32_v_f32m2(wptr + wstep*12, kvl), v, kvl);
-            vs13 = vfmacc_vv_f32m2(vs13, vle32_v_f32m2(wptr + wstep*13, kvl), v, kvl);
-            vs14 = vfmacc_vv_f32m2(vs14, vle32_v_f32m2(wptr + wstep*14, kvl), v, kvl);
+            vl = vsetvl_e32m2(avl);
+            vfloat32m2_t v = vle32_v_f32m2(vec + k, vl);
+            vs0 = vfmacc_vv_f32m2(vs0, vle32_v_f32m2(wptr, vl), v, vl);
+            vs1 = vfmacc_vv_f32m2(vs1, vle32_v_f32m2(wptr + wstep, vl), v, vl);
+            vs2 = vfmacc_vv_f32m2(vs2, vle32_v_f32m2(wptr + wstep*2, vl), v, vl);
+            vs3 = vfmacc_vv_f32m2(vs3, vle32_v_f32m2(wptr + wstep*3, vl), v, vl);
+            vs4 = vfmacc_vv_f32m2(vs4, vle32_v_f32m2(wptr + wstep*4, vl), v, vl);
+            vs5 = vfmacc_vv_f32m2(vs5, vle32_v_f32m2(wptr + wstep*5, vl), v, vl);
+            vs6 = vfmacc_vv_f32m2(vs6, vle32_v_f32m2(wptr + wstep*6, vl), v, vl);
+            vs7 = vfmacc_vv_f32m2(vs7, vle32_v_f32m2(wptr + wstep*7, vl), v, vl);
+            vs8 = vfmacc_vv_f32m2(vs8, vle32_v_f32m2(wptr + wstep*8, vl), v, vl);
+            vs9 = vfmacc_vv_f32m2(vs9, vle32_v_f32m2(wptr + wstep*9, vl), v, vl);
+            vs10 = vfmacc_vv_f32m2(vs10, vle32_v_f32m2(wptr + wstep*10, vl), v, vl);
+            vs11 = vfmacc_vv_f32m2(vs11, vle32_v_f32m2(wptr + wstep*11, vl), v, vl);
+            vs12 = vfmacc_vv_f32m2(vs12, vle32_v_f32m2(wptr + wstep*12, vl), v, vl);
+            vs13 = vfmacc_vv_f32m2(vs13, vle32_v_f32m2(wptr + wstep*13, vl), v, vl);
+            vs14 = vfmacc_vv_f32m2(vs14, vle32_v_f32m2(wptr + wstep*14, vl), v, vl);
         }
         // Calculate the sum of each vector
@@ -925,8 +901,8 @@ void fastGEMM1T( const float* vec, const float* weights,
         vfloat32m4_t s0 = vfadd_vv_f32m4(vle32_v_f32m4(sum, 15), vle32_v_f32m4(bias + i, 15), 15);
         vse32_v_f32m4(dst + i, s0, 15);
     }
-    int mvl = nvecs - i;
-    if (mvl > 0)
+    int unroll_tail = nvecs - i;
+    if (unroll_tail > 0)
     {
         const float* wptr = weights + i*wstep;
         vfloat32m2_t
@@ -935,43 +911,27 @@ void fastGEMM1T( const float* vec, const float* weights,
             vs6 = vfmv_v_f_f32m2(0, vlm2), vs7 = vfmv_v_f_f32m2(0, vlm2), vs8 = vfmv_v_f_f32m2(0, vlm2),
             vs9 = vfmv_v_f_f32m2(0, vlm2), vs10 = vfmv_v_f_f32m2(0, vlm2), vs11 = vfmv_v_f_f32m2(0, vlm2),
             vs12 = vfmv_v_f_f32m2(0, vlm2), vs13 = vfmv_v_f_f32m2(0, vlm2);
-        int k = 0;
-        for( ; k <= vecsize - vlm2; k += vlm2, wptr += vlm2 )
+        int avl = vecsize, vl;
+        for(int k = 0; k < vecsize; k += vl, wptr += vl, avl -= vl)
         {
-            vfloat32m2_t v = vle32_v_f32m2(vec + k, vlm2);
-            vs0 = vfmacc_vv_f32m2(vs0, vle32_v_f32m2(wptr, vlm2), v, vlm2);
-            vs1 = vfmacc_vv_f32m2(vs1, vle32_v_f32m2(wptr + wstep*std::min(1, mvl-1), vlm2), v, vlm2);
-            vs2 = vfmacc_vv_f32m2(vs2, vle32_v_f32m2(wptr + wstep*std::min(2, mvl-1), vlm2), v, vlm2);
-            vs3 = vfmacc_vv_f32m2(vs3, vle32_v_f32m2(wptr + wstep*std::min(3, mvl-1), vlm2), v, vlm2);
-            vs4 = vfmacc_vv_f32m2(vs4, vle32_v_f32m2(wptr + wstep*std::min(4, mvl-1), vlm2), v, vlm2);
-            vs5 = vfmacc_vv_f32m2(vs5, vle32_v_f32m2(wptr + wstep*std::min(5, mvl-1), vlm2), v, vlm2);
-            vs6 = vfmacc_vv_f32m2(vs6, vle32_v_f32m2(wptr + wstep*std::min(6, mvl-1), vlm2), v, vlm2);
-            vs7 = vfmacc_vv_f32m2(vs7, vle32_v_f32m2(wptr + wstep*std::min(7, mvl-1), vlm2), v, vlm2);
-            vs8 = vfmacc_vv_f32m2(vs8, vle32_v_f32m2(wptr + wstep*std::min(8, mvl-1), vlm2), v, vlm2);
-            vs9 = vfmacc_vv_f32m2(vs9, vle32_v_f32m2(wptr + wstep*std::min(9, mvl-1), vlm2), v, vlm2);
-            vs10 = vfmacc_vv_f32m2(vs10, vle32_v_f32m2(wptr + wstep*std::min(10, mvl-1), vlm2), v, vlm2);
-            vs11 = vfmacc_vv_f32m2(vs11, vle32_v_f32m2(wptr + wstep*std::min(11, mvl-1), vlm2), v, vlm2);
-            vs12 = vfmacc_vv_f32m2(vs12, vle32_v_f32m2(wptr + wstep*std::min(12, mvl-1), vlm2), v, vlm2);
-            vs13 = vfmacc_vv_f32m2(vs13, vle32_v_f32m2(wptr + wstep*std::min(13, mvl-1), vlm2), v, vlm2);
-        }
-        int kvl = vecsize - k;
-        if (kvl > 0) {
-            vfloat32m2_t v = vle32_v_f32m2(vec + k, kvl);
-            vs0 = vfmacc_vv_f32m2(vs0, vle32_v_f32m2(wptr, kvl), v, kvl);
-            vs1 = vfmacc_vv_f32m2(vs1, vle32_v_f32m2(wptr + wstep*std::min(1, mvl-1), kvl), v, kvl);
-            vs2 = vfmacc_vv_f32m2(vs2, vle32_v_f32m2(wptr + wstep*std::min(2, mvl-1), kvl), v, kvl);
-            vs3 = vfmacc_vv_f32m2(vs3, vle32_v_f32m2(wptr + wstep*std::min(3, mvl-1), kvl), v, kvl);
-            vs4 = vfmacc_vv_f32m2(vs4, vle32_v_f32m2(wptr + wstep*std::min(4, mvl-1), kvl), v, kvl);
-            vs5 = vfmacc_vv_f32m2(vs5, vle32_v_f32m2(wptr + wstep*std::min(5, mvl-1), kvl), v, kvl);
-            vs6 = vfmacc_vv_f32m2(vs6, vle32_v_f32m2(wptr + wstep*std::min(6, mvl-1), kvl), v, kvl);
-            vs7 = vfmacc_vv_f32m2(vs7, vle32_v_f32m2(wptr + wstep*std::min(7, mvl-1), kvl), v, kvl);
-            vs8 = vfmacc_vv_f32m2(vs8, vle32_v_f32m2(wptr + wstep*std::min(8, mvl-1), kvl), v, kvl);
-            vs9 = vfmacc_vv_f32m2(vs9, vle32_v_f32m2(wptr + wstep*std::min(9, mvl-1), kvl), v, kvl);
-            vs10 = vfmacc_vv_f32m2(vs10, vle32_v_f32m2(wptr + wstep*std::min(10, mvl-1), kvl), v, kvl);
-            vs11 = vfmacc_vv_f32m2(vs11, vle32_v_f32m2(wptr + wstep*std::min(11, mvl-1), kvl), v, kvl);
-            vs12 = vfmacc_vv_f32m2(vs12, vle32_v_f32m2(wptr + wstep*std::min(12, mvl-1), kvl), v, kvl);
-            vs13 = vfmacc_vv_f32m2(vs13, vle32_v_f32m2(wptr + wstep*std::min(13, mvl-1), kvl), v, kvl);
+            vl = vsetvl_e32m2(avl);
+            vfloat32m2_t v = vle32_v_f32m2(vec + k, vl);
+            vs0 = vfmacc_vv_f32m2(vs0, vle32_v_f32m2(wptr, vl), v, vl);
+            vs1 = vfmacc_vv_f32m2(vs1, vle32_v_f32m2(wptr + wstep*std::min(1, unroll_tail-1), vl), v, vl);
+            vs2 = vfmacc_vv_f32m2(vs2, vle32_v_f32m2(wptr + wstep*std::min(2, unroll_tail-1), vl), v, vl);
+            vs3 = vfmacc_vv_f32m2(vs3, vle32_v_f32m2(wptr + wstep*std::min(3, unroll_tail-1), vl), v, vl);
+            vs4 = vfmacc_vv_f32m2(vs4, vle32_v_f32m2(wptr + wstep*std::min(4, unroll_tail-1), vl), v, vl);
+            vs5 = vfmacc_vv_f32m2(vs5, vle32_v_f32m2(wptr + wstep*std::min(5, unroll_tail-1), vl), v, vl);
+            vs6 = vfmacc_vv_f32m2(vs6, vle32_v_f32m2(wptr + wstep*std::min(6, unroll_tail-1), vl), v, vl);
+            vs7 = vfmacc_vv_f32m2(vs7, vle32_v_f32m2(wptr + wstep*std::min(7, unroll_tail-1), vl), v, vl);
+            vs8 = vfmacc_vv_f32m2(vs8, vle32_v_f32m2(wptr + wstep*std::min(8, unroll_tail-1), vl), v, vl);
+            vs9 = vfmacc_vv_f32m2(vs9, vle32_v_f32m2(wptr + wstep*std::min(9, unroll_tail-1), vl), v, vl);
+            vs10 = vfmacc_vv_f32m2(vs10, vle32_v_f32m2(wptr + wstep*std::min(10, unroll_tail-1), vl), v, vl);
+            vs11 = vfmacc_vv_f32m2(vs11, vle32_v_f32m2(wptr + wstep*std::min(11, unroll_tail-1), vl), v, vl);
+            vs12 = vfmacc_vv_f32m2(vs12, vle32_v_f32m2(wptr + wstep*std::min(12, unroll_tail-1), vl), v, vl);
+            vs13 = vfmacc_vv_f32m2(vs13, vle32_v_f32m2(wptr + wstep*std::min(13, unroll_tail-1), vl), v, vl);
         }
         // Calculate the sum of each vector
         float sum[14];
         vfloat32m1_t zero = vfmv_v_f_f32m1(0, vlm2);
@@ -990,8 +950,8 @@ void fastGEMM1T( const float* vec, const float* weights,
         sum[12] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m2_f32m1(zero, vs12, zero, vlm2));
         sum[13] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m2_f32m1(zero, vs13, zero, vlm2));
-        vfloat32m4_t s0 = vfadd_vv_f32m4(vle32_v_f32m4(sum, mvl), vle32_v_f32m4(bias + i, mvl), mvl);
-        vse32_v_f32m4(dst + i, s0, mvl);
+        vfloat32m4_t s0 = vfadd_vv_f32m4(vle32_v_f32m4(sum, unroll_tail), vle32_v_f32m4(bias + i, unroll_tail), unroll_tail);
+        vse32_v_f32m4(dst + i, s0, unroll_tail);
     }
 }
@@ -1001,14 +961,14 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
                int blockSize, int vecsize, int vecsize_aligned,
                const float* relu, bool initOutput )
 {
-    int vl = FASCONV_BASE_VECSZ;
-    int vlm1Max = vsetvlmax_e32m1();
+    const int vlm1 = vsetvlmax_e32m1();
     int outCn = outShape[1];
     size_t outPlaneSize = outShape[2]*outShape[3];
     // now compute dot product of the weights
     // and im2row-transformed part of the tensor
     for( int i = 0; i < outCn; i += 3 )
     {
+        int unroll_tail = FASCONV_BASE_VECSZ;
         const float* wptr0 = weights + i*wstep;
         const float* wptr1 = wptr0 + wstep;
         const float* wptr2 = wptr1 + wstep;
@@ -1033,20 +993,27 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
         int j = 0;
         for( ; j < blockSize; j += FASCONV_BASE_VECSZ )
         {
-            bool tail = false;
+            const float* rptr = rowbuf + j*vecsize_aligned;
+            const float *rptr1 = rptr + vecsize_aligned*1,
+                        *rptr2 = rptr + vecsize_aligned*2,
+                        *rptr3 = rptr + vecsize_aligned*3,
+                        *rptr4 = rptr + vecsize_aligned*4,
+                        *rptr5 = rptr + vecsize_aligned*5,
+                        *rptr6 = rptr + vecsize_aligned*6,
+                        *rptr7 = rptr + vecsize_aligned*7;
             if (j + FASCONV_BASE_VECSZ > blockSize)
             {
-                if (j == 0) {
-                    vl = blockSize;
-                }
-                else {
-                    j = blockSize - FASCONV_BASE_VECSZ;
-                    tail = true;
-                }
+                unroll_tail = blockSize - j;
+                rptr1 = rptr + vecsize_aligned*std::min(1, unroll_tail-1),
+                rptr2 = rptr + vecsize_aligned*std::min(2, unroll_tail-1),
+                rptr3 = rptr + vecsize_aligned*std::min(3, unroll_tail-1),
+                rptr4 = rptr + vecsize_aligned*std::min(4, unroll_tail-1),
+                rptr5 = rptr + vecsize_aligned*std::min(5, unroll_tail-1),
+                rptr6 = rptr + vecsize_aligned*std::min(6, unroll_tail-1),
+                rptr7 = rptr + vecsize_aligned*std::min(7, unroll_tail-1);
             }
-            int k = 0;
-            const float* rptr = rowbuf + j*vecsize_aligned;
-            int vlm1 = vsetvlmax_e32m1();
+            int vl, avl = vecsize;
             vfloat32m1_t
                 vs00 = vfmv_v_f_f32m1(0, vlm1), vs10 = vfmv_v_f_f32m1(0, vlm1), vs20 = vfmv_v_f_f32m1(0, vlm1),
                 vs01 = vfmv_v_f_f32m1(0, vlm1), vs11 = vfmv_v_f_f32m1(0, vlm1), vs21 = vfmv_v_f_f32m1(0, vlm1),
@@ -1057,107 +1024,107 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
                vs06 = vfmv_v_f_f32m1(0, vlm1), vs16 = vfmv_v_f_f32m1(0, vlm1), vs26 = vfmv_v_f_f32m1(0, vlm1),
                vs07 = vfmv_v_f_f32m1(0, vlm1), vs17 = vfmv_v_f_f32m1(0, vlm1), vs27 = vfmv_v_f_f32m1(0, vlm1);
-            for (; k < vecsize; k += vlm1, rptr += vlm1 )
+            for (int k = 0; k < vecsize; k += vl, avl -= vl)
             {
-                if (k + vlm1 >= vecsize) {
-                    vlm1 = vecsize - k;
-                }
-                vfloat32m1_t w0 = vle32_v_f32m1(wptr0 + k, vlm1);
-                vfloat32m1_t w1 = vle32_v_f32m1(wptr1 + k, vlm1);
-                vfloat32m1_t w2 = vle32_v_f32m1(wptr2 + k, vlm1);
-                vfloat32m1_t r0 = vle32_v_f32m1(rptr, vlm1);
+                vl = vsetvl_e32m1(avl);
+                vfloat32m1_t w0 = vle32_v_f32m1(wptr0 + k, vl);
+                vfloat32m1_t w1 = vle32_v_f32m1(wptr1 + k, vl);
+                vfloat32m1_t w2 = vle32_v_f32m1(wptr2 + k, vl);
+                vfloat32m1_t r0 = vle32_v_f32m1(rptr, vl);

-                vs00 = vfmacc_vv_f32m1(vs00, w0, r0, vlm1);
-                vs10 = vfmacc_vv_f32m1(vs10, w1, r0, vlm1);
-                vs20 = vfmacc_vv_f32m1(vs20, w2, r0, vlm1);
+                vs00 = vfmacc_vv_f32m1(vs00, w0, r0, vl);
+                vs10 = vfmacc_vv_f32m1(vs10, w1, r0, vl);
+                vs20 = vfmacc_vv_f32m1(vs20, w2, r0, vl);

-                r0 = vle32_v_f32m1(rptr + vecsize_aligned, vlm1);
-                vs01 = vfmacc_vv_f32m1(vs01, w0, r0, vlm1);
-                vs11 = vfmacc_vv_f32m1(vs11, w1, r0, vlm1);
-                vs21 = vfmacc_vv_f32m1(vs21, w2, r0, vlm1);
+                r0 = vle32_v_f32m1(rptr1, vl);
+                vs01 = vfmacc_vv_f32m1(vs01, w0, r0, vl);
+                vs11 = vfmacc_vv_f32m1(vs11, w1, r0, vl);
+                vs21 = vfmacc_vv_f32m1(vs21, w2, r0, vl);

-                r0 = vle32_v_f32m1(rptr + vecsize_aligned*2, vlm1);
-                vs02 = vfmacc_vv_f32m1(vs02, w0, r0, vlm1);
-                vs12 = vfmacc_vv_f32m1(vs12, w1, r0, vlm1);
-                vs22 = vfmacc_vv_f32m1(vs22, w2, r0, vlm1);
+                r0 = vle32_v_f32m1(rptr2, vl);
+                vs02 = vfmacc_vv_f32m1(vs02, w0, r0, vl);
+                vs12 = vfmacc_vv_f32m1(vs12, w1, r0, vl);
+                vs22 = vfmacc_vv_f32m1(vs22, w2, r0, vl);

-                r0 = vle32_v_f32m1(rptr + vecsize_aligned*3, vlm1);
-                vs03 = vfmacc_vv_f32m1(vs03, w0, r0, vlm1);
-                vs13 = vfmacc_vv_f32m1(vs13, w1, r0, vlm1);
-                vs23 = vfmacc_vv_f32m1(vs23, w2, r0, vlm1);
+                r0 = vle32_v_f32m1(rptr3, vl);
+                vs03 = vfmacc_vv_f32m1(vs03, w0, r0, vl);
+                vs13 = vfmacc_vv_f32m1(vs13, w1, r0, vl);
+                vs23 = vfmacc_vv_f32m1(vs23, w2, r0, vl);

-                r0 = vle32_v_f32m1(rptr + vecsize_aligned*4, vlm1);
-                vs04 = vfmacc_vv_f32m1(vs04, w0, r0, vlm1);
-                vs14 = vfmacc_vv_f32m1(vs14, w1, r0, vlm1);
-                vs24 = vfmacc_vv_f32m1(vs24, w2, r0, vlm1);
+                r0 = vle32_v_f32m1(rptr4, vl);
+                vs04 = vfmacc_vv_f32m1(vs04, w0, r0, vl);
+                vs14 = vfmacc_vv_f32m1(vs14, w1, r0, vl);
+                vs24 = vfmacc_vv_f32m1(vs24, w2, r0, vl);

-                r0 = vle32_v_f32m1(rptr + vecsize_aligned*5, vlm1);
-                vs05 = vfmacc_vv_f32m1(vs05, w0, r0, vlm1);
-                vs15 = vfmacc_vv_f32m1(vs15, w1, r0, vlm1);
-                vs25 = vfmacc_vv_f32m1(vs25, w2, r0, vlm1);
+                r0 = vle32_v_f32m1(rptr5, vl);
+                vs05 = vfmacc_vv_f32m1(vs05, w0, r0, vl);
+                vs15 = vfmacc_vv_f32m1(vs15, w1, r0, vl);
+                vs25 = vfmacc_vv_f32m1(vs25, w2, r0, vl);

-                r0 = vle32_v_f32m1(rptr + vecsize_aligned*6, vlm1);
-                vs06 = vfmacc_vv_f32m1(vs06, w0, r0, vlm1);
-                vs16 = vfmacc_vv_f32m1(vs16, w1, r0, vlm1);
-                vs26 = vfmacc_vv_f32m1(vs26, w2, r0, vlm1);
+                r0 = vle32_v_f32m1(rptr6, vl);
+                vs06 = vfmacc_vv_f32m1(vs06, w0, r0, vl);
+                vs16 = vfmacc_vv_f32m1(vs16, w1, r0, vl);
+                vs26 = vfmacc_vv_f32m1(vs26, w2, r0, vl);

-                r0 = vle32_v_f32m1(rptr + vecsize_aligned*7, vlm1);
-                vs07 = vfmacc_vv_f32m1(vs07, w0, r0, vlm1);
-                vs17 = vfmacc_vv_f32m1(vs17, w1, r0, vlm1);
-                vs27 = vfmacc_vv_f32m1(vs27, w2, r0, vlm1);
+                r0 = vle32_v_f32m1(rptr7, vl);
+                vs07 = vfmacc_vv_f32m1(vs07, w0, r0, vl);
+                vs17 = vfmacc_vv_f32m1(vs17, w1, r0, vl);
+                vs27 = vfmacc_vv_f32m1(vs27, w2, r0, vl);
+
+                rptr += vl; rptr1 += vl; rptr2 += vl; rptr3 += vl;
+                rptr4 += vl; rptr5 += vl; rptr6 += vl; rptr7 += vl;
             }
             // compute sum of each vs
-            vfloat32m1_t zero = vfmv_v_f_f32m1(0, vlm1Max);
-            // vl is required here to be at least FASCONV_BASE_VECSZ, aka 8.
+            vfloat32m1_t zero = vfmv_v_f_f32m1(0, vlm1);
+            // unroll_tail(vl) is required here to be at least FASCONV_BASE_VECSZ, aka 8.
             float sum0[FASCONV_BASE_VECSZ], sum1[FASCONV_BASE_VECSZ], sum2[FASCONV_BASE_VECSZ];
-            sum0[0] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs00, zero, vlm1Max));
-            sum0[1] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs01, zero, vlm1Max));
-            sum0[2] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs02, zero, vlm1Max));
-            sum0[3] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs03, zero, vlm1Max));
-            sum0[4] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs04, zero, vlm1Max));
-            sum0[5] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs05, zero, vlm1Max));
-            sum0[6] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs06, zero, vlm1Max));
-            sum0[7] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs07, zero, vlm1Max));
-            sum1[0] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs10, zero, vlm1Max));
-            sum1[1] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs11, zero, vlm1Max));
-            sum1[2] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs12, zero, vlm1Max));
-            sum1[3] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs13, zero, vlm1Max));
-            sum1[4] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs14, zero, vlm1Max));
-            sum1[5] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs15, zero, vlm1Max));
-            sum1[6] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs16, zero, vlm1Max));
-            sum1[7] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs17, zero, vlm1Max));
-            sum2[0] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs20, zero, vlm1Max));
-            sum2[1] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs21, zero, vlm1Max));
-            sum2[2] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs22, zero, vlm1Max));
-            sum2[3] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs23, zero, vlm1Max));
-            sum2[4] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs24, zero, vlm1Max));
-            sum2[5] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs25, zero, vlm1Max));
-            sum2[6] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs26, zero, vlm1Max));
-            sum2[7] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs27, zero, vlm1Max));
-            // if VLEN = 128, so LMUL = 2 for vl = 8.
+            sum0[0] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs00, zero, vlm1));
+            sum0[1] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs01, zero, vlm1));
+            sum0[2] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs02, zero, vlm1));
+            sum0[3] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs03, zero, vlm1));
+            sum0[4] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs04, zero, vlm1));
+            sum0[5] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs05, zero, vlm1));
+            sum0[6] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs06, zero, vlm1));
+            sum0[7] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs07, zero, vlm1));
+            sum1[0] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs10, zero, vlm1));
+            sum1[1] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs11, zero, vlm1));
+            sum1[2] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs12, zero, vlm1));
+            sum1[3] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs13, zero, vlm1));
+            sum1[4] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs14, zero, vlm1));
+            sum1[5] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs15, zero, vlm1));
+            sum1[6] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs16, zero, vlm1));
+            sum1[7] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs17, zero, vlm1));
+            sum2[0] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs20, zero, vlm1));
+            sum2[1] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs21, zero, vlm1));
+            sum2[2] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs22, zero, vlm1));
+            sum2[3] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs23, zero, vlm1));
+            sum2[4] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs24, zero, vlm1));
+            sum2[5] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs25, zero, vlm1));
+            sum2[6] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs26, zero, vlm1));
+            sum2[7] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs27, zero, vlm1));
+            // if VLEN = 128, so LMUL = 2 for unroll_tail(vl) = 8.
             // otherwise, VLEN >=256, we only use fist 8 element of the vReg.
             vfloat32m2_t s0, s1, s2;
             if( initOutput )
             {
-                s0 = vfmv_v_f_f32m2(bias0, vl);
-                s1 = vfmv_v_f_f32m2(bias1, vl);
-                s2 = vfmv_v_f_f32m2(bias2, vl);
+                s0 = vfmv_v_f_f32m2(bias0, unroll_tail);
+                s1 = vfmv_v_f_f32m2(bias1, unroll_tail);
+                s2 = vfmv_v_f_f32m2(bias2, unroll_tail);
             }
             else
             {
-                s0 = vle32_v_f32m2(outptr0 + j, vl);
-                s1 = vle32_v_f32m2(outptr1 + j, vl);
-                s2 = vle32_v_f32m2(outptr2 + j, vl);
+                s0 = vle32_v_f32m2(outptr0 + j, unroll_tail);
+                s1 = vle32_v_f32m2(outptr1 + j, unroll_tail);
+                s2 = vle32_v_f32m2(outptr2 + j, unroll_tail);
             }
-            s0 = vfadd_vv_f32m2(vle32_v_f32m2(sum0, vl), s0, vl);
-            s1 = vfadd_vv_f32m2(vle32_v_f32m2(sum1, vl), s1, vl);
-            s2 = vfadd_vv_f32m2(vle32_v_f32m2(sum2, vl), s2, vl);
+            s0 = vfadd_vv_f32m2(vle32_v_f32m2(sum0, unroll_tail), s0, unroll_tail);
+            s1 = vfadd_vv_f32m2(vle32_v_f32m2(sum1, unroll_tail), s1, unroll_tail);
+            s2 = vfadd_vv_f32m2(vle32_v_f32m2(sum2, unroll_tail), s2, unroll_tail);
             if( relu )
             {
-                vfloat32m2_t vr0 = vfmv_v_f_f32m2(1, vl), vr1 = vfmv_v_f_f32m2(1, vl), vr2 = vfmv_v_f_f32m2(1, vl);
                 float r0 = relu[i], r1 = relu[i+1], r2 = relu[i+2];
                 if( i+2 >= outCn )
                 {
@@ -1165,33 +1132,17 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
                     if( i+1 >= outCn )
                         r2 = r1 = r0;
                 }
-                vr0 = vfmv_v_f_f32m2(r0, vl);
-                vr1 = vfmv_v_f_f32m2(r1, vl);
-                vr2 = vfmv_v_f_f32m2(r2, vl);
-                vbool16_t m0 = vmfgt_vf_f32m2_b16(s0, 0, vl);
-                vbool16_t m1 = vmfgt_vf_f32m2_b16(s1, 0, vl);
-                vbool16_t m2 = vmfgt_vf_f32m2_b16(s2, 0, vl);
-                s0 = vmerge_vvm_f32m2(m0, vfmul_vv_f32m2(s0, vr0, vl), s0, vl);
-                s1 = vmerge_vvm_f32m2(m1, vfmul_vv_f32m2(s1, vr1, vl), s1, vl);
-                s2 = vmerge_vvm_f32m2(m2, vfmul_vv_f32m2(s2, vr2, vl), s2, vl);
+                vbool16_t m0 = vmfgt_vf_f32m2_b16(s0, 0, unroll_tail);
+                vbool16_t m1 = vmfgt_vf_f32m2_b16(s1, 0, unroll_tail);
+                vbool16_t m2 = vmfgt_vf_f32m2_b16(s2, 0, unroll_tail);
+                s0 = vmerge_vvm_f32m2(m0, vfmul_vf_f32m2(s0, r0, unroll_tail), s0, unroll_tail);
+                s1 = vmerge_vvm_f32m2(m1, vfmul_vf_f32m2(s1, r1, unroll_tail), s1, unroll_tail);
+                s2 = vmerge_vvm_f32m2(m2, vfmul_vf_f32m2(s2, r2, unroll_tail), s2, unroll_tail);
             }
-            if( tail )
-            {
-                int maskbuf[FASCONV_BASE_VECSZ] = {0};
-                int rsz = blockSize % FASCONV_BASE_VECSZ;
-                for( int i = 0; i < rsz; i++ )
-                    maskbuf[FASCONV_BASE_VECSZ - i - 1] = -1;
-                vint32m2_t vmaskbuf = vle32_v_i32m2(maskbuf ,vl);
-                vbool16_t mask = vmslt_vx_i32m2_b16(vmaskbuf, 0, vl); // mask for tail
-                s0 = vmerge_vvm_f32m2(mask, vle32_v_f32m2(outptr0 + j, vl), s0, vl);
-                s1 = vmerge_vvm_f32m2(mask, vle32_v_f32m2(outptr1 + j, vl), s1, vl);
-                s2 = vmerge_vvm_f32m2(mask, vle32_v_f32m2(outptr2 + j, vl), s2, vl);
-            }
-            vse32_v_f32m2(outptr0 + j, s0, vl);
-            vse32_v_f32m2(outptr1 + j, s1, vl);
-            vse32_v_f32m2(outptr2 + j, s2, vl);
+            vse32_v_f32m2(outptr0 + j, s0, unroll_tail);
+            vse32_v_f32m2(outptr1 + j, s1, unroll_tail);
+            vse32_v_f32m2(outptr2 + j, s2, unroll_tail);
         }
     }
 }
@@ -1236,7 +1187,7 @@ void fastDepthwiseConv( const float* wptr,
                         float* outptr_,
                         int out_d, int outH, int outW )
 {
-    int vl = vsetvlmax_e32m2();
+    int vl;
     const float w00_ = wptr[0], w01_ = wptr[1], w02_ = wptr[2],
                 w10 = wptr[3], w11 = wptr[4], w12 = wptr[5],
                 w20_ = wptr[6], w21_ = wptr[7], w22_ = wptr[8];
@@ -1275,11 +1226,11 @@ void fastDepthwiseConv( const float* wptr,
         if (stride_w == 1 || (stride_w == 2 && dilation_w == 1))
         {
+            int avl = outW1 - out_j;
             if( stride_w == 1 )
-                for( ; out_j < outW1; out_j += vl )
+                for( ; out_j < outW1; out_j += vl, avl -= vl)
                 {
-                    if (out_j + vl > outW1)
-                        vl = outW1 - out_j;
+                    vl = vsetvl_e32m2(avl);
                     int in_j = out_j * stride_w - pad_l;
                     vfloat32m2_t v00 = vle32_v_f32m2(imgptr0 + in_j, vl),
                         v01 = vle32_v_f32m2(imgptr0 + in_j + dilation_w, vl),
@@ -1313,10 +1264,9 @@ void fastDepthwiseConv( const float* wptr,
                     vse32_v_f32m2(outptr + out_j, vout0, vl);
                 }
             else //stride_w == 2 && dilation_w == 1
-                for( ; out_j < outW1; out_j += vl )
+                for( ; out_j < outW1; out_j += vl, avl -= vl)
                 {
-                    if (out_j + vl > outW1)
-                        vl = outW1 - out_j;
+                    vl = vsetvl_e32m2(avl);
                     int in_j = out_j * stride_w - pad_l;
                     vfloat32m2_t v00, v01, v02, v10, v11, v12, v20, v21, v22, unused;
                     vfloat32m2_load_deinterleave(imgptr0 + in_j, v00, v01, vl);
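
A note on the other tail handled by this patch, the register-unrolled output dimension in fastGEMM1T and fastConv: when fewer than the unroll factor of outputs remain (unroll_tail), the row pointers for the out-of-range unroll slots are clamped to the last valid row with std::min(n, unroll_tail-1). The fully unrolled body then runs unchanged (the surplus slots merely recompute the last valid row), and the final stores pass unroll_tail as the element count so only valid results are written back. A reduced, illustrative sketch of that clamping idea, not code from the patch:

    #include <algorithm>

    // Sum up to 3 rows at once; 'valid' may be 1..3. Out-of-range slots alias the last
    // valid row, so the unrolled body needs no separate tail version (illustrative sketch).
    void sum_rows3(const float* rows, int rowstep, int valid, int n, float* out)
    {
        const float* r0 = rows;
        const float* r1 = rows + rowstep*std::min(1, valid-1);
        const float* r2 = rows + rowstep*std::min(2, valid-1);
        float s0 = 0.f, s1 = 0.f, s2 = 0.f;
        for (int k = 0; k < n; k++)
        {
            s0 += r0[k];
            s1 += r1[k];  // duplicates s0 when valid == 1
            s2 += r2[k];  // duplicates the last valid row when valid < 3
        }
        for (int m = 0; m < valid; m++)       // store only the valid outputs
            out[m] = (m == 0 ? s0 : m == 1 ? s1 : s2);
    }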
