@ -782,15 +782,10 @@ void fastGEMM( const float* aptr, size_t astep, const float* bptr,
size_t bstep , float * cptr , size_t cstep ,
size_t bstep , float * cptr , size_t cstep ,
int ma , int na , int nb )
int ma , int na , int nb )
{
{
int n = 0 ;
int avl = nb , vl ;
int vl = vsetvlmax_e32m4 ( ) ;
for ( int n = 0 ; n < nb ; n + = vl , avl - = vl )
int mvl = vl ;
for ( ; n < nb ; n + = vl )
{
{
if ( n + vl > nb ) {
vl = vsetvl_e32m4 ( avl ) ;
mvl = nb - n ;
}
for ( int m = 0 ; m < ma ; m + = 7 )
for ( int m = 0 ; m < ma ; m + = 7 )
{
{
const float * aptr0 = aptr + astep * m ;
const float * aptr0 = aptr + astep * m ;
@ -827,22 +822,22 @@ void fastGEMM( const float* aptr, size_t astep, const float* bptr,
float a5 = aptr5 [ k ] ;
float a5 = aptr5 [ k ] ;
float a6 = aptr6 [ k ] ;
float a6 = aptr6 [ k ] ;
vfloat32m4_t b = vle32_v_f32m4 ( bptr + k * bstep + n , m vl) ;
vfloat32m4_t b = vle32_v_f32m4 ( bptr + k * bstep + n , vl ) ;
d0 = vfmacc_vf_f32m4 ( d0 , a0 , b , m vl) ;
d0 = vfmacc_vf_f32m4 ( d0 , a0 , b , vl ) ;
d1 = vfmacc_vf_f32m4 ( d1 , a1 , b , m vl) ;
d1 = vfmacc_vf_f32m4 ( d1 , a1 , b , vl ) ;
d2 = vfmacc_vf_f32m4 ( d2 , a2 , b , m vl) ;
d2 = vfmacc_vf_f32m4 ( d2 , a2 , b , vl ) ;
d3 = vfmacc_vf_f32m4 ( d3 , a3 , b , m vl) ;
d3 = vfmacc_vf_f32m4 ( d3 , a3 , b , vl ) ;
d4 = vfmacc_vf_f32m4 ( d4 , a4 , b , m vl) ;
d4 = vfmacc_vf_f32m4 ( d4 , a4 , b , vl ) ;
d5 = vfmacc_vf_f32m4 ( d5 , a5 , b , m vl) ;
d5 = vfmacc_vf_f32m4 ( d5 , a5 , b , vl ) ;
d6 = vfmacc_vf_f32m4 ( d6 , a6 , b , m vl) ;
d6 = vfmacc_vf_f32m4 ( d6 , a6 , b , vl ) ;
}
}
vse32_v_f32m4 ( cptr0 + n , d0 , m vl) ;
vse32_v_f32m4 ( cptr0 + n , d0 , vl ) ;
vse32_v_f32m4 ( cptr1 + n , d1 , m vl) ;
vse32_v_f32m4 ( cptr1 + n , d1 , vl ) ;
vse32_v_f32m4 ( cptr2 + n , d2 , m vl) ;
vse32_v_f32m4 ( cptr2 + n , d2 , vl ) ;
vse32_v_f32m4 ( cptr3 + n , d3 , m vl) ;
vse32_v_f32m4 ( cptr3 + n , d3 , vl ) ;
vse32_v_f32m4 ( cptr4 + n , d4 , m vl) ;
vse32_v_f32m4 ( cptr4 + n , d4 , vl ) ;
vse32_v_f32m4 ( cptr5 + n , d5 , m vl) ;
vse32_v_f32m4 ( cptr5 + n , d5 , vl ) ;
vse32_v_f32m4 ( cptr6 + n , d6 , m vl) ;
vse32_v_f32m4 ( cptr6 + n , d6 , vl ) ;
}
}
}
}
}
}
@ -851,7 +846,7 @@ void fastGEMM1T( const float* vec, const float* weights,
size_t wstep , const float * bias ,
size_t wstep , const float * bias ,
float * dst , int nvecs , int vecsize )
float * dst , int nvecs , int vecsize )
{
{
int vlm2 = vsetvlmax_e32m2 ( ) ;
const int vlm2 = vsetvlmax_e32m2 ( ) ;
int i = 0 ;
int i = 0 ;
for ( ; i < = nvecs - 15 ; i + = 15 )
for ( ; i < = nvecs - 15 ; i + = 15 )
{
{
@ -862,45 +857,26 @@ void fastGEMM1T( const float* vec, const float* weights,
vs6 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs7 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs8 = vfmv_v_f_f32m2 ( 0 , vlm2 ) ,
vs6 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs7 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs8 = vfmv_v_f_f32m2 ( 0 , vlm2 ) ,
vs9 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs10 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs11 = vfmv_v_f_f32m2 ( 0 , vlm2 ) ,
vs9 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs10 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs11 = vfmv_v_f_f32m2 ( 0 , vlm2 ) ,
vs12 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs13 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs14 = vfmv_v_f_f32m2 ( 0 , vlm2 ) ;
vs12 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs13 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs14 = vfmv_v_f_f32m2 ( 0 , vlm2 ) ;
int k = 0 ;
int avl = vecsize , vl ;
for ( ; k < vecsize - vlm2 ; k + = vlm2 , wptr + = vlm2 )
for ( int k = 0 ; k < vecsize ; k + = vl , wptr + = vl , avl - = vl )
{
{
vfloat32m2_t v = vle32_v_f32m2 ( vec + k , vlm2 ) ;
vl = vsetvl_e32m2 ( avl ) ;
vfloat32m2_t v = vle32_v_f32m2 ( vec + k , vl ) ;
vs0 = vfmacc_vv_f32m2 ( vs0 , vle32_v_f32m2 ( wptr , vlm2 ) , v , vlm2 ) ;
vs0 = vfmacc_vv_f32m2 ( vs0 , vle32_v_f32m2 ( wptr , vl ) , v , vl ) ;
vs1 = vfmacc_vv_f32m2 ( vs1 , vle32_v_f32m2 ( wptr + wstep , vlm2 ) , v , vlm2 ) ;
vs1 = vfmacc_vv_f32m2 ( vs1 , vle32_v_f32m2 ( wptr + wstep , vl ) , v , vl ) ;
vs2 = vfmacc_vv_f32m2 ( vs2 , vle32_v_f32m2 ( wptr + wstep * 2 , vlm2 ) , v , vlm2 ) ;
vs2 = vfmacc_vv_f32m2 ( vs2 , vle32_v_f32m2 ( wptr + wstep * 2 , vl ) , v , vl ) ;
vs3 = vfmacc_vv_f32m2 ( vs3 , vle32_v_f32m2 ( wptr + wstep * 3 , vlm2 ) , v , vlm2 ) ;
vs3 = vfmacc_vv_f32m2 ( vs3 , vle32_v_f32m2 ( wptr + wstep * 3 , vl ) , v , vl ) ;
vs4 = vfmacc_vv_f32m2 ( vs4 , vle32_v_f32m2 ( wptr + wstep * 4 , vlm2 ) , v , vlm2 ) ;
vs4 = vfmacc_vv_f32m2 ( vs4 , vle32_v_f32m2 ( wptr + wstep * 4 , vl ) , v , vl ) ;
vs5 = vfmacc_vv_f32m2 ( vs5 , vle32_v_f32m2 ( wptr + wstep * 5 , vlm2 ) , v , vlm2 ) ;
vs5 = vfmacc_vv_f32m2 ( vs5 , vle32_v_f32m2 ( wptr + wstep * 5 , vl ) , v , vl ) ;
vs6 = vfmacc_vv_f32m2 ( vs6 , vle32_v_f32m2 ( wptr + wstep * 6 , vlm2 ) , v , vlm2 ) ;
vs6 = vfmacc_vv_f32m2 ( vs6 , vle32_v_f32m2 ( wptr + wstep * 6 , vl ) , v , vl ) ;
vs7 = vfmacc_vv_f32m2 ( vs7 , vle32_v_f32m2 ( wptr + wstep * 7 , vlm2 ) , v , vlm2 ) ;
vs7 = vfmacc_vv_f32m2 ( vs7 , vle32_v_f32m2 ( wptr + wstep * 7 , vl ) , v , vl ) ;
vs8 = vfmacc_vv_f32m2 ( vs8 , vle32_v_f32m2 ( wptr + wstep * 8 , vlm2 ) , v , vlm2 ) ;
vs8 = vfmacc_vv_f32m2 ( vs8 , vle32_v_f32m2 ( wptr + wstep * 8 , vl ) , v , vl ) ;
vs9 = vfmacc_vv_f32m2 ( vs9 , vle32_v_f32m2 ( wptr + wstep * 9 , vlm2 ) , v , vlm2 ) ;
vs9 = vfmacc_vv_f32m2 ( vs9 , vle32_v_f32m2 ( wptr + wstep * 9 , vl ) , v , vl ) ;
vs10 = vfmacc_vv_f32m2 ( vs10 , vle32_v_f32m2 ( wptr + wstep * 10 , vlm2 ) , v , vlm2 ) ;
vs10 = vfmacc_vv_f32m2 ( vs10 , vle32_v_f32m2 ( wptr + wstep * 10 , vl ) , v , vl ) ;
vs11 = vfmacc_vv_f32m2 ( vs11 , vle32_v_f32m2 ( wptr + wstep * 11 , vlm2 ) , v , vlm2 ) ;
vs11 = vfmacc_vv_f32m2 ( vs11 , vle32_v_f32m2 ( wptr + wstep * 11 , vl ) , v , vl ) ;
vs12 = vfmacc_vv_f32m2 ( vs12 , vle32_v_f32m2 ( wptr + wstep * 12 , vlm2 ) , v , vlm2 ) ;
vs12 = vfmacc_vv_f32m2 ( vs12 , vle32_v_f32m2 ( wptr + wstep * 12 , vl ) , v , vl ) ;
vs13 = vfmacc_vv_f32m2 ( vs13 , vle32_v_f32m2 ( wptr + wstep * 13 , vlm2 ) , v , vlm2 ) ;
vs13 = vfmacc_vv_f32m2 ( vs13 , vle32_v_f32m2 ( wptr + wstep * 13 , vl ) , v , vl ) ;
vs14 = vfmacc_vv_f32m2 ( vs14 , vle32_v_f32m2 ( wptr + wstep * 14 , vlm2 ) , v , vlm2 ) ;
vs14 = vfmacc_vv_f32m2 ( vs14 , vle32_v_f32m2 ( wptr + wstep * 14 , vl ) , v , vl ) ;
}
int kvl = vecsize - k ;
if ( kvl > 0 ) {
vfloat32m2_t v = vle32_v_f32m2 ( vec + k , kvl ) ;
vs0 = vfmacc_vv_f32m2 ( vs0 , vle32_v_f32m2 ( wptr , kvl ) , v , kvl ) ;
vs1 = vfmacc_vv_f32m2 ( vs1 , vle32_v_f32m2 ( wptr + wstep * 1 , kvl ) , v , kvl ) ;
vs2 = vfmacc_vv_f32m2 ( vs2 , vle32_v_f32m2 ( wptr + wstep * 2 , kvl ) , v , kvl ) ;
vs3 = vfmacc_vv_f32m2 ( vs3 , vle32_v_f32m2 ( wptr + wstep * 3 , kvl ) , v , kvl ) ;
vs4 = vfmacc_vv_f32m2 ( vs4 , vle32_v_f32m2 ( wptr + wstep * 4 , kvl ) , v , kvl ) ;
vs5 = vfmacc_vv_f32m2 ( vs5 , vle32_v_f32m2 ( wptr + wstep * 5 , kvl ) , v , kvl ) ;
vs6 = vfmacc_vv_f32m2 ( vs6 , vle32_v_f32m2 ( wptr + wstep * 6 , kvl ) , v , kvl ) ;
vs7 = vfmacc_vv_f32m2 ( vs7 , vle32_v_f32m2 ( wptr + wstep * 7 , kvl ) , v , kvl ) ;
vs8 = vfmacc_vv_f32m2 ( vs8 , vle32_v_f32m2 ( wptr + wstep * 8 , kvl ) , v , kvl ) ;
vs9 = vfmacc_vv_f32m2 ( vs9 , vle32_v_f32m2 ( wptr + wstep * 9 , kvl ) , v , kvl ) ;
vs10 = vfmacc_vv_f32m2 ( vs10 , vle32_v_f32m2 ( wptr + wstep * 10 , kvl ) , v , kvl ) ;
vs11 = vfmacc_vv_f32m2 ( vs11 , vle32_v_f32m2 ( wptr + wstep * 11 , kvl ) , v , kvl ) ;
vs12 = vfmacc_vv_f32m2 ( vs12 , vle32_v_f32m2 ( wptr + wstep * 12 , kvl ) , v , kvl ) ;
vs13 = vfmacc_vv_f32m2 ( vs13 , vle32_v_f32m2 ( wptr + wstep * 13 , kvl ) , v , kvl ) ;
vs14 = vfmacc_vv_f32m2 ( vs14 , vle32_v_f32m2 ( wptr + wstep * 14 , kvl ) , v , kvl ) ;
}
}
// Calculate the sum of each vector
// Calculate the sum of each vector
@ -925,8 +901,8 @@ void fastGEMM1T( const float* vec, const float* weights,
vfloat32m4_t s0 = vfadd_vv_f32m4 ( vle32_v_f32m4 ( sum , 15 ) , vle32_v_f32m4 ( bias + i , 15 ) , 15 ) ;
vfloat32m4_t s0 = vfadd_vv_f32m4 ( vle32_v_f32m4 ( sum , 15 ) , vle32_v_f32m4 ( bias + i , 15 ) , 15 ) ;
vse32_v_f32m4 ( dst + i , s0 , 15 ) ;
vse32_v_f32m4 ( dst + i , s0 , 15 ) ;
}
}
int mv l = nvecs - i ;
int unroll_tai l = nvecs - i ;
if ( mv l > 0 )
if ( unroll_tai l > 0 )
{
{
const float * wptr = weights + i * wstep ;
const float * wptr = weights + i * wstep ;
vfloat32m2_t
vfloat32m2_t
@ -935,43 +911,27 @@ void fastGEMM1T( const float* vec, const float* weights,
vs6 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs7 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs8 = vfmv_v_f_f32m2 ( 0 , vlm2 ) ,
vs6 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs7 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs8 = vfmv_v_f_f32m2 ( 0 , vlm2 ) ,
vs9 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs10 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs11 = vfmv_v_f_f32m2 ( 0 , vlm2 ) ,
vs9 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs10 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs11 = vfmv_v_f_f32m2 ( 0 , vlm2 ) ,
vs12 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs13 = vfmv_v_f_f32m2 ( 0 , vlm2 ) ;
vs12 = vfmv_v_f_f32m2 ( 0 , vlm2 ) , vs13 = vfmv_v_f_f32m2 ( 0 , vlm2 ) ;
int k = 0 ;
int avl = vecsize , vl ;
for ( ; k < = vecsize - vlm2 ; k + = vlm2 , wptr + = vlm2 )
for ( int k = 0 ; k < vecsize ; k + = vl , wptr + = vl , avl - = vl )
{
{
vfloat32m2_t v = vle32_v_f32m2 ( vec + k , vlm2 ) ;
vl = vsetvl_e32m2 ( avl ) ;
vs0 = vfmacc_vv_f32m2 ( vs0 , vle32_v_f32m2 ( wptr , vlm2 ) , v , vlm2 ) ;
vfloat32m2_t v = vle32_v_f32m2 ( vec + k , vl ) ;
vs1 = vfmacc_vv_f32m2 ( vs1 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 1 , mvl - 1 ) , vlm2 ) , v , vlm2 ) ;
vs0 = vfmacc_vv_f32m2 ( vs0 , vle32_v_f32m2 ( wptr , vl ) , v , vl ) ;
vs2 = vfmacc_vv_f32m2 ( vs2 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 2 , mvl - 1 ) , vlm2 ) , v , vlm2 ) ;
vs1 = vfmacc_vv_f32m2 ( vs1 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 1 , unroll_tail - 1 ) , vl ) , v , vl ) ;
vs3 = vfmacc_vv_f32m2 ( vs3 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 3 , mvl - 1 ) , vlm2 ) , v , vlm2 ) ;
vs2 = vfmacc_vv_f32m2 ( vs2 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 2 , unroll_tail - 1 ) , vl ) , v , vl ) ;
vs4 = vfmacc_vv_f32m2 ( vs4 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 4 , mvl - 1 ) , vlm2 ) , v , vlm2 ) ;
vs3 = vfmacc_vv_f32m2 ( vs3 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 3 , unroll_tail - 1 ) , vl ) , v , vl ) ;
vs5 = vfmacc_vv_f32m2 ( vs5 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 5 , mvl - 1 ) , vlm2 ) , v , vlm2 ) ;
vs4 = vfmacc_vv_f32m2 ( vs4 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 4 , unroll_tail - 1 ) , vl ) , v , vl ) ;
vs6 = vfmacc_vv_f32m2 ( vs6 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 6 , mvl - 1 ) , vlm2 ) , v , vlm2 ) ;
vs5 = vfmacc_vv_f32m2 ( vs5 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 5 , unroll_tail - 1 ) , vl ) , v , vl ) ;
vs7 = vfmacc_vv_f32m2 ( vs7 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 7 , mvl - 1 ) , vlm2 ) , v , vlm2 ) ;
vs6 = vfmacc_vv_f32m2 ( vs6 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 6 , unroll_tail - 1 ) , vl ) , v , vl ) ;
vs8 = vfmacc_vv_f32m2 ( vs8 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 8 , mvl - 1 ) , vlm2 ) , v , vlm2 ) ;
vs7 = vfmacc_vv_f32m2 ( vs7 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 7 , unroll_tail - 1 ) , vl ) , v , vl ) ;
vs9 = vfmacc_vv_f32m2 ( vs9 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 9 , mvl - 1 ) , vlm2 ) , v , vlm2 ) ;
vs8 = vfmacc_vv_f32m2 ( vs8 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 8 , unroll_tail - 1 ) , vl ) , v , vl ) ;
vs10 = vfmacc_vv_f32m2 ( vs10 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 10 , mvl - 1 ) , vlm2 ) , v , vlm2 ) ;
vs9 = vfmacc_vv_f32m2 ( vs9 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 9 , unroll_tail - 1 ) , vl ) , v , vl ) ;
vs11 = vfmacc_vv_f32m2 ( vs11 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 11 , mvl - 1 ) , vlm2 ) , v , vlm2 ) ;
vs10 = vfmacc_vv_f32m2 ( vs10 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 10 , unroll_tail - 1 ) , vl ) , v , vl ) ;
vs12 = vfmacc_vv_f32m2 ( vs12 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 12 , mvl - 1 ) , vlm2 ) , v , vlm2 ) ;
vs11 = vfmacc_vv_f32m2 ( vs11 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 11 , unroll_tail - 1 ) , vl ) , v , vl ) ;
vs13 = vfmacc_vv_f32m2 ( vs13 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 13 , mvl - 1 ) , vlm2 ) , v , vlm2 ) ;
vs12 = vfmacc_vv_f32m2 ( vs12 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 12 , unroll_tail - 1 ) , vl ) , v , vl ) ;
}
vs13 = vfmacc_vv_f32m2 ( vs13 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 13 , unroll_tail - 1 ) , vl ) , v , vl ) ;
int kvl = vecsize - k ;
if ( kvl > 0 ) {
vfloat32m2_t v = vle32_v_f32m2 ( vec + k , kvl ) ;
vs0 = vfmacc_vv_f32m2 ( vs0 , vle32_v_f32m2 ( wptr , kvl ) , v , kvl ) ;
vs1 = vfmacc_vv_f32m2 ( vs1 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 1 , mvl - 1 ) , kvl ) , v , kvl ) ;
vs2 = vfmacc_vv_f32m2 ( vs2 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 2 , mvl - 1 ) , kvl ) , v , kvl ) ;
vs3 = vfmacc_vv_f32m2 ( vs3 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 3 , mvl - 1 ) , kvl ) , v , kvl ) ;
vs4 = vfmacc_vv_f32m2 ( vs4 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 4 , mvl - 1 ) , kvl ) , v , kvl ) ;
vs5 = vfmacc_vv_f32m2 ( vs5 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 5 , mvl - 1 ) , kvl ) , v , kvl ) ;
vs6 = vfmacc_vv_f32m2 ( vs6 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 6 , mvl - 1 ) , kvl ) , v , kvl ) ;
vs7 = vfmacc_vv_f32m2 ( vs7 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 7 , mvl - 1 ) , kvl ) , v , kvl ) ;
vs8 = vfmacc_vv_f32m2 ( vs8 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 8 , mvl - 1 ) , kvl ) , v , kvl ) ;
vs9 = vfmacc_vv_f32m2 ( vs9 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 9 , mvl - 1 ) , kvl ) , v , kvl ) ;
vs10 = vfmacc_vv_f32m2 ( vs10 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 10 , mvl - 1 ) , kvl ) , v , kvl ) ;
vs11 = vfmacc_vv_f32m2 ( vs11 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 11 , mvl - 1 ) , kvl ) , v , kvl ) ;
vs12 = vfmacc_vv_f32m2 ( vs12 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 12 , mvl - 1 ) , kvl ) , v , kvl ) ;
vs13 = vfmacc_vv_f32m2 ( vs13 , vle32_v_f32m2 ( wptr + wstep * std : : min ( 13 , mvl - 1 ) , kvl ) , v , kvl ) ;
}
}
// Calculate the sum of each vector
// Calculate the sum of each vector
float sum [ 14 ] ;
float sum [ 14 ] ;
vfloat32m1_t zero = vfmv_v_f_f32m1 ( 0 , vlm2 ) ;
vfloat32m1_t zero = vfmv_v_f_f32m1 ( 0 , vlm2 ) ;
@ -990,8 +950,8 @@ void fastGEMM1T( const float* vec, const float* weights,
sum [ 12 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m2_f32m1 ( zero , vs12 , zero , vlm2 ) ) ;
sum [ 12 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m2_f32m1 ( zero , vs12 , zero , vlm2 ) ) ;
sum [ 13 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m2_f32m1 ( zero , vs13 , zero , vlm2 ) ) ;
sum [ 13 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m2_f32m1 ( zero , vs13 , zero , vlm2 ) ) ;
vfloat32m4_t s0 = vfadd_vv_f32m4 ( vle32_v_f32m4 ( sum , mv l) , vle32_v_f32m4 ( bias + i , mvl ) , mv l) ;
vfloat32m4_t s0 = vfadd_vv_f32m4 ( vle32_v_f32m4 ( sum , unroll_tai l) , vle32_v_f32m4 ( bias + i , unroll_tail ) , unroll_tai l) ;
vse32_v_f32m4 ( dst + i , s0 , mv l) ;
vse32_v_f32m4 ( dst + i , s0 , unroll_tai l) ;
}
}
}
}
@ -1001,14 +961,14 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
int blockSize , int vecsize , int vecsize_aligned ,
int blockSize , int vecsize , int vecsize_aligned ,
const float * relu , bool initOutput )
const float * relu , bool initOutput )
{
{
int vl = FASCONV_BASE_VECSZ ;
const int vlm1 = vsetvlmax_e32m1 ( ) ;
int vlm1Max = vsetvlmax_e32m1 ( ) ;
int outCn = outShape [ 1 ] ;
int outCn = outShape [ 1 ] ;
size_t outPlaneSize = outShape [ 2 ] * outShape [ 3 ] ;
size_t outPlaneSize = outShape [ 2 ] * outShape [ 3 ] ;
// now compute dot product of the weights
// now compute dot product of the weights
// and im2row-transformed part of the tensor
// and im2row-transformed part of the tensor
for ( int i = 0 ; i < outCn ; i + = 3 )
for ( int i = 0 ; i < outCn ; i + = 3 )
{
{
int unroll_tail = FASCONV_BASE_VECSZ ;
const float * wptr0 = weights + i * wstep ;
const float * wptr0 = weights + i * wstep ;
const float * wptr1 = wptr0 + wstep ;
const float * wptr1 = wptr0 + wstep ;
const float * wptr2 = wptr1 + wstep ;
const float * wptr2 = wptr1 + wstep ;
@ -1033,20 +993,27 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
int j = 0 ;
int j = 0 ;
for ( ; j < blockSize ; j + = FASCONV_BASE_VECSZ )
for ( ; j < blockSize ; j + = FASCONV_BASE_VECSZ )
{
{
bool tail = false ;
const float * rptr = rowbuf + j * vecsize_aligned ;
const float * rptr1 = rptr + vecsize_aligned * 1 ,
* rptr2 = rptr + vecsize_aligned * 2 ,
* rptr3 = rptr + vecsize_aligned * 3 ,
* rptr4 = rptr + vecsize_aligned * 4 ,
* rptr5 = rptr + vecsize_aligned * 5 ,
* rptr6 = rptr + vecsize_aligned * 6 ,
* rptr7 = rptr + vecsize_aligned * 7 ;
if ( j + FASCONV_BASE_VECSZ > blockSize )
if ( j + FASCONV_BASE_VECSZ > blockSize )
{
{
if ( j = = 0 ) {
unroll_tail = blockSize - j ;
vl = blockSize ;
rptr1 = rptr + vecsize_aligned * std : : min ( 1 , unroll_tail - 1 ) ,
}
rptr2 = rptr + vecsize_aligned * std : : min ( 2 , unroll_tail - 1 ) ,
else {
rptr3 = rptr + vecsize_aligned * std : : min ( 3 , unroll_tail - 1 ) ,
j = blockSize - FASCONV_BASE_VECSZ ;
rptr4 = rptr + vecsize_aligned * std : : min ( 4 , unroll_tail - 1 ) ,
tail = true ;
rptr5 = rptr + vecsize_aligned * std : : min ( 5 , unroll_tail - 1 ) ,
}
rptr6 = rptr + vecsize_aligned * std : : min ( 6 , unroll_tail - 1 ) ,
rptr7 = rptr + vecsize_aligned * std : : min ( 7 , unroll_tail - 1 ) ;
}
}
int k = 0 ;
const float * rptr = rowbuf + j * vecsize_aligned ;
int vl , avl = vecsize ;
int vlm1 = vsetvlmax_e32m1 ( ) ;
vfloat32m1_t
vfloat32m1_t
vs00 = vfmv_v_f_f32m1 ( 0 , vlm1 ) , vs10 = vfmv_v_f_f32m1 ( 0 , vlm1 ) , vs20 = vfmv_v_f_f32m1 ( 0 , vlm1 ) ,
vs00 = vfmv_v_f_f32m1 ( 0 , vlm1 ) , vs10 = vfmv_v_f_f32m1 ( 0 , vlm1 ) , vs20 = vfmv_v_f_f32m1 ( 0 , vlm1 ) ,
vs01 = vfmv_v_f_f32m1 ( 0 , vlm1 ) , vs11 = vfmv_v_f_f32m1 ( 0 , vlm1 ) , vs21 = vfmv_v_f_f32m1 ( 0 , vlm1 ) ,
vs01 = vfmv_v_f_f32m1 ( 0 , vlm1 ) , vs11 = vfmv_v_f_f32m1 ( 0 , vlm1 ) , vs21 = vfmv_v_f_f32m1 ( 0 , vlm1 ) ,
@ -1057,107 +1024,107 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
vs06 = vfmv_v_f_f32m1 ( 0 , vlm1 ) , vs16 = vfmv_v_f_f32m1 ( 0 , vlm1 ) , vs26 = vfmv_v_f_f32m1 ( 0 , vlm1 ) ,
vs06 = vfmv_v_f_f32m1 ( 0 , vlm1 ) , vs16 = vfmv_v_f_f32m1 ( 0 , vlm1 ) , vs26 = vfmv_v_f_f32m1 ( 0 , vlm1 ) ,
vs07 = vfmv_v_f_f32m1 ( 0 , vlm1 ) , vs17 = vfmv_v_f_f32m1 ( 0 , vlm1 ) , vs27 = vfmv_v_f_f32m1 ( 0 , vlm1 ) ;
vs07 = vfmv_v_f_f32m1 ( 0 , vlm1 ) , vs17 = vfmv_v_f_f32m1 ( 0 , vlm1 ) , vs27 = vfmv_v_f_f32m1 ( 0 , vlm1 ) ;
for ( ; k < vecsize ; k + = vlm1 , rptr + = vlm1 )
for ( int k = 0 ; k < vecsize ; k + = vl , avl - = vl )
{
{
if ( k + vlm1 > = vecsize ) {
vl = vsetvl_e32m1 ( avl ) ;
vlm1 = vecsize - k ;
vfloat32m1_t w0 = vle32_v_f32m1 ( wptr0 + k , vl ) ;
}
vfloat32m1_t w1 = vle32_v_f32m1 ( wptr1 + k , vl ) ;
vfloat32m1_t w0 = vle32_v_f32m1 ( wptr0 + k , vlm1 ) ;
vfloat32m1_t w2 = vle32_v_f32m1 ( wptr2 + k , vl ) ;
vfloat32m1_t w1 = vle32_v_f32m1 ( wptr1 + k , vlm1 ) ;
vfloat32m1_t r0 = vle32_v_f32m1 ( rptr , vl ) ;
vfloat32m1_t w2 = vle32_v_f32m1 ( wptr2 + k , vlm1 ) ;
vfloat32m1_t r0 = vle32_v_f32m1 ( rptr , vlm1 ) ;
vs00 = vfmacc_vv_f32m1 ( vs00 , w0 , r0 , vl ) ;
vs10 = vfmacc_vv_f32m1 ( vs10 , w1 , r0 , vl ) ;
vs00 = vfmacc_vv_f32m1 ( vs00 , w0 , r0 , vlm1 ) ;
vs20 = vfmacc_vv_f32m1 ( vs20 , w2 , r0 , vl ) ;
vs10 = vfmacc_vv_f32m1 ( vs10 , w1 , r0 , vlm1 ) ;
vs20 = vfmacc_vv_f32m1 ( vs20 , w2 , r0 , vlm1 ) ;
r0 = vle32_v_f32m1 ( rptr1 , vl ) ;
vs01 = vfmacc_vv_f32m1 ( vs01 , w0 , r0 , vl ) ;
r0 = vle32_v_f32m1 ( rptr + vecsize_aligned , vlm1 ) ;
vs11 = vfmacc_vv_f32m1 ( vs11 , w1 , r0 , vl ) ;
vs01 = vfmacc_vv_f32m1 ( vs01 , w0 , r0 , vlm1 ) ;
vs21 = vfmacc_vv_f32m1 ( vs21 , w2 , r0 , vl ) ;
vs11 = vfmacc_vv_f32m1 ( vs11 , w1 , r0 , vlm1 ) ;
vs21 = vfmacc_vv_f32m1 ( vs21 , w2 , r0 , vlm1 ) ;
r0 = vle32_v_f32m1 ( rptr2 , vl ) ;
vs02 = vfmacc_vv_f32m1 ( vs02 , w0 , r0 , vl ) ;
r0 = vle32_v_f32m1 ( rptr + vecsize_aligned * 2 , vlm1 ) ;
vs12 = vfmacc_vv_f32m1 ( vs12 , w1 , r0 , vl ) ;
vs02 = vfmacc_vv_f32m1 ( vs02 , w0 , r0 , vlm1 ) ;
vs22 = vfmacc_vv_f32m1 ( vs22 , w2 , r0 , vl ) ;
vs12 = vfmacc_vv_f32m1 ( vs12 , w1 , r0 , vlm1 ) ;
vs22 = vfmacc_vv_f32m1 ( vs22 , w2 , r0 , vlm1 ) ;
r0 = vle32_v_f32m1 ( rptr3 , vl ) ;
vs03 = vfmacc_vv_f32m1 ( vs03 , w0 , r0 , vl ) ;
r0 = vle32_v_f32m1 ( rptr + vecsize_aligned * 3 , vlm1 ) ;
vs13 = vfmacc_vv_f32m1 ( vs13 , w1 , r0 , vl ) ;
vs03 = vfmacc_vv_f32m1 ( vs03 , w0 , r0 , vlm1 ) ;
vs23 = vfmacc_vv_f32m1 ( vs23 , w2 , r0 , vl ) ;
vs13 = vfmacc_vv_f32m1 ( vs13 , w1 , r0 , vlm1 ) ;
vs23 = vfmacc_vv_f32m1 ( vs23 , w2 , r0 , vlm1 ) ;
r0 = vle32_v_f32m1 ( rptr4 , vl ) ;
vs04 = vfmacc_vv_f32m1 ( vs04 , w0 , r0 , vl ) ;
r0 = vle32_v_f32m1 ( rptr + vecsize_aligned * 4 , vlm1 ) ;
vs14 = vfmacc_vv_f32m1 ( vs14 , w1 , r0 , vl ) ;
vs04 = vfmacc_vv_f32m1 ( vs04 , w0 , r0 , vlm1 ) ;
vs24 = vfmacc_vv_f32m1 ( vs24 , w2 , r0 , vl ) ;
vs14 = vfmacc_vv_f32m1 ( vs14 , w1 , r0 , vlm1 ) ;
vs24 = vfmacc_vv_f32m1 ( vs24 , w2 , r0 , vlm1 ) ;
r0 = vle32_v_f32m1 ( rptr5 , vl ) ;
vs05 = vfmacc_vv_f32m1 ( vs05 , w0 , r0 , vl ) ;
r0 = vle32_v_f32m1 ( rptr + vecsize_aligned * 5 , vlm1 ) ;
vs15 = vfmacc_vv_f32m1 ( vs15 , w1 , r0 , vl ) ;
vs05 = vfmacc_vv_f32m1 ( vs05 , w0 , r0 , vlm1 ) ;
vs25 = vfmacc_vv_f32m1 ( vs25 , w2 , r0 , vl ) ;
vs15 = vfmacc_vv_f32m1 ( vs15 , w1 , r0 , vlm1 ) ;
vs25 = vfmacc_vv_f32m1 ( vs25 , w2 , r0 , vlm1 ) ;
r0 = vle32_v_f32m1 ( rptr6 , vl ) ;
vs06 = vfmacc_vv_f32m1 ( vs06 , w0 , r0 , vl ) ;
r0 = vle32_v_f32m1 ( rptr + vecsize_aligned * 6 , vlm1 ) ;
vs16 = vfmacc_vv_f32m1 ( vs16 , w1 , r0 , vl ) ;
vs06 = vfmacc_vv_f32m1 ( vs06 , w0 , r0 , vlm1 ) ;
vs26 = vfmacc_vv_f32m1 ( vs26 , w2 , r0 , vl ) ;
vs16 = vfmacc_vv_f32m1 ( vs16 , w1 , r0 , vlm1 ) ;
vs26 = vfmacc_vv_f32m1 ( vs26 , w2 , r0 , vlm1 ) ;
r0 = vle32_v_f32m1 ( rptr7 , vl ) ;
vs07 = vfmacc_vv_f32m1 ( vs07 , w0 , r0 , vl ) ;
r0 = vle32_v_f32m1 ( rptr + vecsize_aligned * 7 , vlm1 ) ;
vs17 = vfmacc_vv_f32m1 ( vs17 , w1 , r0 , vl ) ;
vs07 = vfmacc_vv_f32m1 ( vs07 , w0 , r0 , vlm1 ) ;
vs27 = vfmacc_vv_f32m1 ( vs27 , w2 , r0 , vl ) ;
vs17 = vfmacc_vv_f32m1 ( vs17 , w1 , r0 , vlm1 ) ;
vs27 = vfmacc_vv_f32m1 ( vs27 , w2 , r0 , vlm1 ) ;
rptr + = vl ; rptr1 + = vl ; rptr2 + = vl ; rptr3 + = vl ;
rptr4 + = vl ; rptr5 + = vl ; rptr6 + = vl ; rptr7 + = vl ;
}
}
// compute sum of each vs
// compute sum of each vs
vfloat32m1_t zero = vfmv_v_f_f32m1 ( 0 , vlm1Max ) ;
vfloat32m1_t zero = vfmv_v_f_f32m1 ( 0 , vlm1 ) ;
// vl is required here to be at least FASCONV_BASE_VECSZ, aka 8.
// unroll_tail( vl) is required here to be at least FASCONV_BASE_VECSZ, aka 8.
float sum0 [ FASCONV_BASE_VECSZ ] , sum1 [ FASCONV_BASE_VECSZ ] , sum2 [ FASCONV_BASE_VECSZ ] ;
float sum0 [ FASCONV_BASE_VECSZ ] , sum1 [ FASCONV_BASE_VECSZ ] , sum2 [ FASCONV_BASE_VECSZ ] ;
sum0 [ 0 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs00 , zero , vlm1Max ) ) ;
sum0 [ 0 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs00 , zero , vlm1 ) ) ;
sum0 [ 1 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs01 , zero , vlm1Max ) ) ;
sum0 [ 1 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs01 , zero , vlm1 ) ) ;
sum0 [ 2 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs02 , zero , vlm1Max ) ) ;
sum0 [ 2 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs02 , zero , vlm1 ) ) ;
sum0 [ 3 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs03 , zero , vlm1Max ) ) ;
sum0 [ 3 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs03 , zero , vlm1 ) ) ;
sum0 [ 4 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs04 , zero , vlm1Max ) ) ;
sum0 [ 4 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs04 , zero , vlm1 ) ) ;
sum0 [ 5 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs05 , zero , vlm1Max ) ) ;
sum0 [ 5 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs05 , zero , vlm1 ) ) ;
sum0 [ 6 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs06 , zero , vlm1Max ) ) ;
sum0 [ 6 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs06 , zero , vlm1 ) ) ;
sum0 [ 7 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs07 , zero , vlm1Max ) ) ;
sum0 [ 7 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs07 , zero , vlm1 ) ) ;
sum1 [ 0 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs10 , zero , vlm1Max ) ) ;
sum1 [ 0 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs10 , zero , vlm1 ) ) ;
sum1 [ 1 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs11 , zero , vlm1Max ) ) ;
sum1 [ 1 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs11 , zero , vlm1 ) ) ;
sum1 [ 2 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs12 , zero , vlm1Max ) ) ;
sum1 [ 2 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs12 , zero , vlm1 ) ) ;
sum1 [ 3 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs13 , zero , vlm1Max ) ) ;
sum1 [ 3 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs13 , zero , vlm1 ) ) ;
sum1 [ 4 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs14 , zero , vlm1Max ) ) ;
sum1 [ 4 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs14 , zero , vlm1 ) ) ;
sum1 [ 5 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs15 , zero , vlm1Max ) ) ;
sum1 [ 5 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs15 , zero , vlm1 ) ) ;
sum1 [ 6 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs16 , zero , vlm1Max ) ) ;
sum1 [ 6 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs16 , zero , vlm1 ) ) ;
sum1 [ 7 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs17 , zero , vlm1Max ) ) ;
sum1 [ 7 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs17 , zero , vlm1 ) ) ;
sum2 [ 0 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs20 , zero , vlm1Max ) ) ;
sum2 [ 0 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs20 , zero , vlm1 ) ) ;
sum2 [ 1 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs21 , zero , vlm1Max ) ) ;
sum2 [ 1 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs21 , zero , vlm1 ) ) ;
sum2 [ 2 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs22 , zero , vlm1Max ) ) ;
sum2 [ 2 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs22 , zero , vlm1 ) ) ;
sum2 [ 3 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs23 , zero , vlm1Max ) ) ;
sum2 [ 3 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs23 , zero , vlm1 ) ) ;
sum2 [ 4 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs24 , zero , vlm1Max ) ) ;
sum2 [ 4 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs24 , zero , vlm1 ) ) ;
sum2 [ 5 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs25 , zero , vlm1Max ) ) ;
sum2 [ 5 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs25 , zero , vlm1 ) ) ;
sum2 [ 6 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs26 , zero , vlm1Max ) ) ;
sum2 [ 6 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs26 , zero , vlm1 ) ) ;
sum2 [ 7 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs27 , zero , vlm1Max ) ) ;
sum2 [ 7 ] = vfmv_f_s_f32m1_f32 ( vfredosum_vs_f32m1_f32m1 ( zero , vs27 , zero , vlm1 ) ) ;
// if VLEN = 128, so LMUL = 2 for vl = 8.
// if VLEN = 128, so LMUL = 2 for unroll_tail( vl) = 8.
// otherwise, VLEN >=256, we only use fist 8 element of the vReg.
// otherwise, VLEN >=256, we only use fist 8 element of the vReg.
vfloat32m2_t s0 , s1 , s2 ;
vfloat32m2_t s0 , s1 , s2 ;
if ( initOutput )
if ( initOutput )
{
{
s0 = vfmv_v_f_f32m2 ( bias0 , v l) ;
s0 = vfmv_v_f_f32m2 ( bias0 , unroll_tai l) ;
s1 = vfmv_v_f_f32m2 ( bias1 , v l) ;
s1 = vfmv_v_f_f32m2 ( bias1 , unroll_tai l) ;
s2 = vfmv_v_f_f32m2 ( bias2 , v l) ;
s2 = vfmv_v_f_f32m2 ( bias2 , unroll_tai l) ;
}
}
else
else
{
{
s0 = vle32_v_f32m2 ( outptr0 + j , v l) ;
s0 = vle32_v_f32m2 ( outptr0 + j , unroll_tai l) ;
s1 = vle32_v_f32m2 ( outptr1 + j , v l) ;
s1 = vle32_v_f32m2 ( outptr1 + j , unroll_tai l) ;
s2 = vle32_v_f32m2 ( outptr2 + j , v l) ;
s2 = vle32_v_f32m2 ( outptr2 + j , unroll_tai l) ;
}
}
s0 = vfadd_vv_f32m2 ( vle32_v_f32m2 ( sum0 , v l) , s0 , v l) ;
s0 = vfadd_vv_f32m2 ( vle32_v_f32m2 ( sum0 , unroll_tai l) , s0 , unroll_tai l) ;
s1 = vfadd_vv_f32m2 ( vle32_v_f32m2 ( sum1 , v l) , s1 , v l) ;
s1 = vfadd_vv_f32m2 ( vle32_v_f32m2 ( sum1 , unroll_tai l) , s1 , unroll_tai l) ;
s2 = vfadd_vv_f32m2 ( vle32_v_f32m2 ( sum2 , v l) , s2 , v l) ;
s2 = vfadd_vv_f32m2 ( vle32_v_f32m2 ( sum2 , unroll_tai l) , s2 , unroll_tai l) ;
if ( relu )
if ( relu )
{
{
vfloat32m2_t vr0 = vfmv_v_f_f32m2 ( 1 , vl ) , vr1 = vfmv_v_f_f32m2 ( 1 , vl ) , vr2 = vfmv_v_f_f32m2 ( 1 , vl ) ;
float r0 = relu [ i ] , r1 = relu [ i + 1 ] , r2 = relu [ i + 2 ] ;
float r0 = relu [ i ] , r1 = relu [ i + 1 ] , r2 = relu [ i + 2 ] ;
if ( i + 2 > = outCn )
if ( i + 2 > = outCn )
{
{
@ -1165,33 +1132,17 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
if ( i + 1 > = outCn )
if ( i + 1 > = outCn )
r2 = r1 = r0 ;
r2 = r1 = r0 ;
}
}
vr0 = vfmv_v_f_f32m2 ( r0 , vl ) ;
vbool16_t m0 = vmfgt_vf_f32m2_b16 ( s0 , 0 , unroll_tail ) ;
vr1 = vfmv_v_f_f32m2 ( r1 , vl ) ;
vbool16_t m1 = vmfgt_vf_f32m2_b16 ( s1 , 0 , unroll_tail ) ;
vr2 = vfmv_v_f_f32m2 ( r2 , vl ) ;
vbool16_t m2 = vmfgt_vf_f32m2_b16 ( s2 , 0 , unroll_tail ) ;
vbool16_t m0 = vmfgt_vf_f32m2_b16 ( s0 , 0 , vl ) ;
s0 = vmerge_vvm_f32m2 ( m0 , vfmul_vf_f32m2 ( s0 , r0 , unroll_tail ) , s0 , unroll_tail ) ;
vbool16_t m1 = vmfgt_vf_f32m2_b16 ( s1 , 0 , vl ) ;
s1 = vmerge_vvm_f32m2 ( m1 , vfmul_vf_f32m2 ( s1 , r1 , unroll_tail ) , s1 , unroll_tail ) ;
vbool16_t m2 = vmfgt_vf_f32m2_b16 ( s2 , 0 , vl ) ;
s2 = vmerge_vvm_f32m2 ( m2 , vfmul_vf_f32m2 ( s2 , r2 , unroll_tail ) , s2 , unroll_tail ) ;
s0 = vmerge_vvm_f32m2 ( m0 , vfmul_vv_f32m2 ( s0 , vr0 , vl ) , s0 , vl ) ;
s1 = vmerge_vvm_f32m2 ( m1 , vfmul_vv_f32m2 ( s1 , vr1 , vl ) , s1 , vl ) ;
s2 = vmerge_vvm_f32m2 ( m2 , vfmul_vv_f32m2 ( s2 , vr2 , vl ) , s2 , vl ) ;
}
if ( tail )
{
int maskbuf [ FASCONV_BASE_VECSZ ] = { 0 } ;
int rsz = blockSize % FASCONV_BASE_VECSZ ;
for ( int i = 0 ; i < rsz ; i + + )
maskbuf [ FASCONV_BASE_VECSZ - i - 1 ] = - 1 ;
vint32m2_t vmaskbuf = vle32_v_i32m2 ( maskbuf , vl ) ;
vbool16_t mask = vmslt_vx_i32m2_b16 ( vmaskbuf , 0 , vl ) ; // mask for tail
s0 = vmerge_vvm_f32m2 ( mask , vle32_v_f32m2 ( outptr0 + j , vl ) , s0 , vl ) ;
s1 = vmerge_vvm_f32m2 ( mask , vle32_v_f32m2 ( outptr1 + j , vl ) , s1 , vl ) ;
s2 = vmerge_vvm_f32m2 ( mask , vle32_v_f32m2 ( outptr2 + j , vl ) , s2 , vl ) ;
}
}
vse32_v_f32m2 ( outptr0 + j , s0 , v l) ;
vse32_v_f32m2 ( outptr0 + j , s0 , unroll_tail ) ;
vse32_v_f32m2 ( outptr1 + j , s1 , v l) ;
vse32_v_f32m2 ( outptr1 + j , s1 , unroll_tail ) ;
vse32_v_f32m2 ( outptr2 + j , s2 , v l) ;
vse32_v_f32m2 ( outptr2 + j , s2 , unroll_tail ) ;
}
}
}
}
}
}
@ -1236,7 +1187,7 @@ void fastDepthwiseConv( const float* wptr,
float * outptr_ ,
float * outptr_ ,
int out_d , int outH , int outW )
int out_d , int outH , int outW )
{
{
int vl = vsetvlmax_e32m2 ( ) ;
int vl ;
const float w00_ = wptr [ 0 ] , w01_ = wptr [ 1 ] , w02_ = wptr [ 2 ] ,
const float w00_ = wptr [ 0 ] , w01_ = wptr [ 1 ] , w02_ = wptr [ 2 ] ,
w10 = wptr [ 3 ] , w11 = wptr [ 4 ] , w12 = wptr [ 5 ] ,
w10 = wptr [ 3 ] , w11 = wptr [ 4 ] , w12 = wptr [ 5 ] ,
w20_ = wptr [ 6 ] , w21_ = wptr [ 7 ] , w22_ = wptr [ 8 ] ;
w20_ = wptr [ 6 ] , w21_ = wptr [ 7 ] , w22_ = wptr [ 8 ] ;
@ -1275,11 +1226,11 @@ void fastDepthwiseConv( const float* wptr,
if ( stride_w = = 1 | | ( stride_w = = 2 & & dilation_w = = 1 ) )
if ( stride_w = = 1 | | ( stride_w = = 2 & & dilation_w = = 1 ) )
{
{
int avl = outW1 - out_j ;
if ( stride_w = = 1 )
if ( stride_w = = 1 )
for ( ; out_j < outW1 ; out_j + = vl )
for ( ; out_j < outW1 ; out_j + = vl , avl - = vl )
{
{
if ( out_j + vl > outW1 )
vl = vsetvl_e32m2 ( avl ) ;
vl = outW1 - out_j ;
int in_j = out_j * stride_w - pad_l ;
int in_j = out_j * stride_w - pad_l ;
vfloat32m2_t v00 = vle32_v_f32m2 ( imgptr0 + in_j , vl ) ,
vfloat32m2_t v00 = vle32_v_f32m2 ( imgptr0 + in_j , vl ) ,
v01 = vle32_v_f32m2 ( imgptr0 + in_j + dilation_w , vl ) ,
v01 = vle32_v_f32m2 ( imgptr0 + in_j + dilation_w , vl ) ,
@ -1313,10 +1264,9 @@ void fastDepthwiseConv( const float* wptr,
vse32_v_f32m2 ( outptr + out_j , vout0 , vl ) ;
vse32_v_f32m2 ( outptr + out_j , vout0 , vl ) ;
}
}
else //stride_w == 2 && dilation_w == 1
else //stride_w == 2 && dilation_w == 1
for ( ; out_j < outW1 ; out_j + = vl )
for ( ; out_j < outW1 ; out_j + = vl , avl - = vl )
{
{
if ( out_j + vl > outW1 )
vl = vsetvl_e32m2 ( avl ) ;
vl = outW1 - out_j ;
int in_j = out_j * stride_w - pad_l ;
int in_j = out_j * stride_w - pad_l ;
vfloat32m2_t v00 , v01 , v02 , v10 , v11 , v12 , v20 , v21 , v22 , unused ;
vfloat32m2_t v00 , v01 , v02 , v10 , v11 , v12 , v20 , v21 , v22 , unused ;
vfloat32m2_load_deinterleave ( imgptr0 + in_j , v00 , v01 , vl ) ;
vfloat32m2_load_deinterleave ( imgptr0 + in_j , v00 , v01 , vl ) ;