Further optimize fastDepthwiseConv for RVV.

pull/25361/head
Liutong HAN 1 year ago
parent 5121a1bf0d
commit 5be158a2b6
  1. 100
      modules/dnn/src/layers/cpu_kernels/conv_depthwise.simd.hpp

@ -209,34 +209,6 @@ void fastDepthwiseConv( const float* wptr,
#if !defined(CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY) && CV_RVV
/*
Load 2*vl consecutive floats starting at ptr and split them into two vectors:
one holding the even-indexed elements and one holding the odd-indexed elements.

Example for load_deinterleave:
input: ptr[16] = {1,2,3, ... ,14,15,16}
output: a = {1, 3, 5, 7, 9, 11, 13, 15}
output: b = {2, 4, 6, 8,10, 12, 14, 16}

ptr - source data; 2*vl floats are read from it.
a   - receives ptr[0], ptr[2], ptr[4], ... (vl elements).
b   - receives ptr[1], ptr[3], ptr[5], ... (vl elements).
vl  - number of elements written to each of a and b.
*/
static inline void vfloat32m2_load_deinterleave(const float* ptr, vfloat32m2_t& a, vfloat32m2_t& b, int vl)
{
    // Number of interleaved source elements processed by the m4 operations.
    const int vl2 = vl*2;

    // Fill 64-bit lanes with the value 1 and view them as 32-bit lanes: the
    // 32-bit view alternates between 1 and 0, which gives a per-element
    // selector for the two interleaved streams (see the example above).
    vuint64m4_t mask = vmv_v_x_u64m4(1, vl2);
    vuint32m4_t mask_re = vreinterpret_v_u64m4_u32m4(mask);
    vbool8_t mask0 = vmseq_vx_u32m4_b8(mask_re, 1, vl2);
    vbool8_t mask1 = vmseq_vx_u32m4_b8(mask_re, 0, vl2);

    // Load all 2*vl elements at once, then pack the selected elements of each
    // stream to the front of a temporary m4 register.
    vfloat32m4_t tempa = vundefined_f32m4(), tempb = vundefined_f32m4();
    vfloat32m4_t vw = vle32_v_f32m4(ptr, vl2);
    tempa = vcompress_vm_f32m4(mask0, tempa, vw, vl2);
    tempb = vcompress_vm_f32m4(mask1, tempb, vw, vl2);

    /* The following intrinsics are not yet supported by the GNU toolchain,
    so we temporarily use store and load instead.
    // a = vlmul_trunc_v_f32m4_f32m2(tempa);
    // b = vlmul_trunc_v_f32m4_f32m2(tempb);
    */
    // AutoBuffer is sized in elements, not bytes; only vl floats are ever
    // stored to / loaded from the scratch buffer below.
    cv::AutoBuffer<float> cvBuffer(vl);
    float* buffer = cvBuffer.data();
    vse32_v_f32m4(buffer, tempa, vl);
    a = vle32_v_f32m2(buffer, vl);
    vse32_v_f32m4(buffer, tempb, vl);
    b = vle32_v_f32m2(buffer, vl);
}
void fastDepthwiseConv( const float* wptr,
int kernel_h, int kernel_w,
@ -292,64 +264,40 @@ void fastDepthwiseConv( const float* wptr,
if( stride_w == 1 )
for( ; out_j < outW1; out_j += vl, avl -= vl)
{
vl = vsetvl_e32m2(avl);
vl = vsetvl_e32m8(avl);
int in_j = out_j * stride_w - pad_l;
vfloat32m2_t v00 = vle32_v_f32m2(imgptr0 + in_j, vl),
v01 = vle32_v_f32m2(imgptr0 + in_j + dilation_w, vl),
v02 = vle32_v_f32m2(imgptr0 + in_j + dilation_w*2, vl),
v10 = vle32_v_f32m2(imgptr1 + in_j, vl),
v11 = vle32_v_f32m2(imgptr1 + in_j + dilation_w, vl),
v12 = vle32_v_f32m2(imgptr1 + in_j + dilation_w*2, vl),
v20 = vle32_v_f32m2(imgptr2 + in_j, vl),
v21 = vle32_v_f32m2(imgptr2 + in_j + dilation_w, vl),
v22 = vle32_v_f32m2(imgptr2 + in_j + dilation_w*2, vl);
vfloat32m2_t vout0 = vfmul_vf_f32m2(v00, w00, vl);
vfloat32m2_t vout1 = vfmul_vf_f32m2(v01, w01, vl);
vfloat32m2_t vout2 = vfmul_vf_f32m2(v02, w02, vl);
vout0 = vfadd_vf_f32m2(vout0, bias, vl);
vout0 = vfmacc_vf_f32m2(vout0, w10, v10, vl);
vout1 = vfmacc_vf_f32m2(vout1, w11, v11, vl);
vout2 = vfmacc_vf_f32m2(vout2, w12, v12, vl);
vout0 = vfmacc_vf_f32m2(vout0, w20, v20, vl);
vout1 = vfmacc_vf_f32m2(vout1, w21, v21, vl);
vout2 = vfmacc_vf_f32m2(vout2, w22, v22, vl);
vout0 = vfadd_vv_f32m2(vfadd_vv_f32m2(vout0, vout1, vl), vout2, vl);
vfloat32m8_t vout0 = vfmacc_vf_f32m8(vfmv_v_f_f32m8(bias, vl), w00, vle32_v_f32m8(imgptr0 + in_j, vl), vl);
vout0 = vfmacc_vf_f32m8(vout0, w01, vle32_v_f32m8(imgptr0 + in_j + dilation_w, vl), vl);
vout0 = vfmacc_vf_f32m8(vout0, w02, vle32_v_f32m8(imgptr0 + in_j + dilation_w*2, vl), vl);
vout0 = vfmacc_vf_f32m8(vout0, w10, vle32_v_f32m8(imgptr1 + in_j, vl),vl);
vout0 = vfmacc_vf_f32m8(vout0, w11, vle32_v_f32m8(imgptr1 + in_j + dilation_w, vl),vl);
vout0 = vfmacc_vf_f32m8(vout0, w12, vle32_v_f32m8(imgptr1 + in_j + dilation_w*2, vl),vl);
vout0 = vfmacc_vf_f32m8(vout0, w20, vle32_v_f32m8(imgptr2 + in_j, vl), vl);
vout0 = vfmacc_vf_f32m8(vout0, w21, vle32_v_f32m8(imgptr2 + in_j + dilation_w, vl), vl);
vout0 = vfmacc_vf_f32m8(vout0, w22, vle32_v_f32m8(imgptr2 + in_j + dilation_w*2, vl), vl);
if (relu)
{
vbool16_t m = vmfgt_vf_f32m2_b16(vout0, 0, vl);
vout0 = vmerge_vvm_f32m2(m, vfmul_vf_f32m2(vout0, relu_coeff, vl), vout0, vl);
vbool4_t m = vmfgt_vf_f32m8_b4(vout0, 0, vl);
vout0 = vmerge_vvm_f32m8(m, vfmul_vf_f32m8(vout0, relu_coeff, vl), vout0, vl);
}
vse32_v_f32m2(outptr + out_j, vout0, vl);
vse32_v_f32m8(outptr + out_j, vout0, vl);
}
else //stride_w == 2 && dilation_w == 1
for( ; out_j < outW1; out_j += vl, avl -= vl)
{
vl = vsetvl_e32m2(avl);
int in_j = out_j * stride_w - pad_l;
vfloat32m2_t v00, v01, v02, v10, v11, v12, v20, v21, v22, unused;
vfloat32m2_load_deinterleave(imgptr0 + in_j, v00, v01, vl);
vfloat32m2_load_deinterleave(imgptr0 + in_j + 2, v02, unused, vl);
vfloat32m2_load_deinterleave(imgptr1 + in_j, v10, v11, vl);
vfloat32m2_load_deinterleave(imgptr1 + in_j + 2, v12, unused, vl);
vfloat32m2_load_deinterleave(imgptr2 + in_j, v20, v21, vl);
vfloat32m2_load_deinterleave(imgptr2 + in_j + 2, v22, unused, vl);
vfloat32m2_t vout0 = vfmul_vf_f32m2(v00, w00, vl);
vfloat32m2_t vout1 = vfmul_vf_f32m2(v01, w01, vl);
vfloat32m2_t vout2 = vfmul_vf_f32m2(v02, w02, vl);
vout0 = vfadd_vf_f32m2(vout0, bias, vl);
vout0 = vfmacc_vf_f32m2(vout0, w10, v10, vl);
vout1 = vfmacc_vf_f32m2(vout1, w11, v11, vl);
vout2 = vfmacc_vf_f32m2(vout2, w12, v12, vl);
vout0 = vfmacc_vf_f32m2(vout0, w20, v20, vl);
vout1 = vfmacc_vf_f32m2(vout1, w21, v21, vl);
vout2 = vfmacc_vf_f32m2(vout2, w22, v22, vl);
vfloat32m2_t vout0 = vfmacc_vf_f32m2(vfmv_v_f_f32m2(bias, vl), w00, vlse32_v_f32m2(imgptr0+in_j , 8, vl), vl);
vfloat32m2_t vout1 = vfmul_vf_f32m2(vlse32_v_f32m2(imgptr0+in_j+1, 8, vl), w01, vl);
vfloat32m2_t vout2 = vfmul_vf_f32m2(vlse32_v_f32m2(imgptr0+in_j+2, 8, vl), w02, vl);
vout0 = vfmacc_vf_f32m2(vout0, w10, vlse32_v_f32m2(imgptr1+in_j , 8, vl), vl);
vout1 = vfmacc_vf_f32m2(vout1, w11, vlse32_v_f32m2(imgptr1+in_j+1, 8, vl), vl);
vout2 = vfmacc_vf_f32m2(vout2, w12, vlse32_v_f32m2(imgptr1+in_j+2, 8, vl), vl);
vout0 = vfmacc_vf_f32m2(vout0, w20, vlse32_v_f32m2(imgptr2+in_j , 8, vl), vl);
vout1 = vfmacc_vf_f32m2(vout1, w21, vlse32_v_f32m2(imgptr2+in_j+1, 8, vl), vl);
vout2 = vfmacc_vf_f32m2(vout2, w22, vlse32_v_f32m2(imgptr2+in_j+2, 8, vl), vl);
vout0 = vfadd_vv_f32m2(vfadd_vv_f32m2(vout0, vout1, vl), vout2, vl);
if (relu)

Loading…
Cancel
Save