From 5be158a2b6ed0f4f4def851d3a57c3e3b5865ad5 Mon Sep 17 00:00:00 2001
From: Liutong HAN
Date: Sun, 7 Apr 2024 11:34:41 +0800
Subject: [PATCH] Further optimize fastDepthwiseConv for RVV.

---
 .../cpu_kernels/conv_depthwise.simd.hpp       | 100 +++++-------------
 1 file changed, 24 insertions(+), 76 deletions(-)

diff --git a/modules/dnn/src/layers/cpu_kernels/conv_depthwise.simd.hpp b/modules/dnn/src/layers/cpu_kernels/conv_depthwise.simd.hpp
index 1d561e9864..6d4b211b8c 100644
--- a/modules/dnn/src/layers/cpu_kernels/conv_depthwise.simd.hpp
+++ b/modules/dnn/src/layers/cpu_kernels/conv_depthwise.simd.hpp
@@ -209,34 +209,6 @@ void fastDepthwiseConv( const float* wptr,
 
 #if !defined(CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY) && CV_RVV
 
-/*
-Example for load_deinterleave:
-  input:  ptr[16] = {1,2,3, ... ,14,15,16}
-  output: a = {1, 3, 5, 7,  9, 11, 13, 15}
-  output: b = {2, 4, 6, 8, 10, 12, 14, 16}
-*/
-static inline void vfloat32m2_load_deinterleave(const float* ptr, vfloat32m2_t& a, vfloat32m2_t& b, int vl)
-{
-    vuint64m4_t mask = vmv_v_x_u64m4(1, vl*2);
-    vuint32m4_t mask_re = vreinterpret_v_u64m4_u32m4(mask);
-    vbool8_t mask0 = vmseq_vx_u32m4_b8(mask_re, 1, vl*2);
-    vbool8_t mask1 = vmseq_vx_u32m4_b8(mask_re, 0, vl*2);
-    vfloat32m4_t tempa = vundefined_f32m4(), tempb = vundefined_f32m4();
-    vfloat32m4_t vw = vle32_v_f32m4(ptr, vl*2);
-    tempa = vcompress_vm_f32m4(mask0, tempa, vw, vl*2);
-    tempb = vcompress_vm_f32m4(mask1, tempb, vw, vl*2);
-    /* The following instructions are not yet supported by the GNU toolchain,
-       so we temporarily use store and load instead.
-    // a = vlmul_trunc_v_f32m4_f32m2(tempa);
-    // b = vlmul_trunc_v_f32m4_f32m2(tempb);
-    */
-    cv::AutoBuffer<float> cvBuffer(sizeof(float)*vl*2);
-    float* buffer = (float*)cvBuffer.data();
-    vse32_v_f32m4(buffer, tempa, vl);
-    a = vle32_v_f32m2(buffer, vl);
-    vse32_v_f32m4(buffer, tempb, vl);
-    b = vle32_v_f32m2(buffer, vl);
-}
 
 void fastDepthwiseConv( const float* wptr,
                         int kernel_h, int kernel_w,
@@ -292,64 +264,40 @@ void fastDepthwiseConv( const float* wptr,
 
         if( stride_w == 1 )
             for( ; out_j < outW1; out_j += vl, avl -= vl)
             {
-                vl = vsetvl_e32m2(avl);
+                vl = vsetvl_e32m8(avl);
                 int in_j = out_j * stride_w - pad_l;
-                vfloat32m2_t v00 = vle32_v_f32m2(imgptr0 + in_j, vl),
-                    v01 = vle32_v_f32m2(imgptr0 + in_j + dilation_w, vl),
-                    v02 = vle32_v_f32m2(imgptr0 + in_j + dilation_w*2, vl),
-                    v10 = vle32_v_f32m2(imgptr1 + in_j, vl),
-                    v11 = vle32_v_f32m2(imgptr1 + in_j + dilation_w, vl),
-                    v12 = vle32_v_f32m2(imgptr1 + in_j + dilation_w*2, vl),
-                    v20 = vle32_v_f32m2(imgptr2 + in_j, vl),
-                    v21 = vle32_v_f32m2(imgptr2 + in_j + dilation_w, vl),
-                    v22 = vle32_v_f32m2(imgptr2 + in_j + dilation_w*2, vl);
-
-                vfloat32m2_t vout0 = vfmul_vf_f32m2(v00, w00, vl);
-                vfloat32m2_t vout1 = vfmul_vf_f32m2(v01, w01, vl);
-                vfloat32m2_t vout2 = vfmul_vf_f32m2(v02, w02, vl);
-                vout0 = vfadd_vf_f32m2(vout0, bias, vl);
-
-                vout0 = vfmacc_vf_f32m2(vout0, w10, v10, vl);
-                vout1 = vfmacc_vf_f32m2(vout1, w11, v11, vl);
-                vout2 = vfmacc_vf_f32m2(vout2, w12, v12, vl);
-
-                vout0 = vfmacc_vf_f32m2(vout0, w20, v20, vl);
-                vout1 = vfmacc_vf_f32m2(vout1, w21, v21, vl);
-                vout2 = vfmacc_vf_f32m2(vout2, w22, v22, vl);
-
-                vout0 = vfadd_vv_f32m2(vfadd_vv_f32m2(vout0, vout1, vl), vout2, vl);
+                vfloat32m8_t vout0 = vfmacc_vf_f32m8(vfmv_v_f_f32m8(bias, vl), w00, vle32_v_f32m8(imgptr0 + in_j, vl), vl);
+                vout0 = vfmacc_vf_f32m8(vout0, w01, vle32_v_f32m8(imgptr0 + in_j + dilation_w, vl), vl);
+                vout0 = vfmacc_vf_f32m8(vout0, w02, vle32_v_f32m8(imgptr0 + in_j + dilation_w*2, vl), vl);
+                vout0 = vfmacc_vf_f32m8(vout0, w10, vle32_v_f32m8(imgptr1 + in_j, vl), vl);
+                vout0 = vfmacc_vf_f32m8(vout0, w11, vle32_v_f32m8(imgptr1 + in_j + dilation_w, vl), vl);
+                vout0 = vfmacc_vf_f32m8(vout0, w12, vle32_v_f32m8(imgptr1 + in_j + dilation_w*2, vl), vl);
+                vout0 = vfmacc_vf_f32m8(vout0, w20, vle32_v_f32m8(imgptr2 + in_j, vl), vl);
+                vout0 = vfmacc_vf_f32m8(vout0, w21, vle32_v_f32m8(imgptr2 + in_j + dilation_w, vl), vl);
+                vout0 = vfmacc_vf_f32m8(vout0, w22, vle32_v_f32m8(imgptr2 + in_j + dilation_w*2, vl), vl);
                 if (relu)
                 {
-                    vbool16_t m = vmfgt_vf_f32m2_b16(vout0, 0, vl);
-                    vout0 = vmerge_vvm_f32m2(m, vfmul_vf_f32m2(vout0, relu_coeff, vl), vout0, vl);
+                    vbool4_t m = vmfgt_vf_f32m8_b4(vout0, 0, vl);
+                    vout0 = vmerge_vvm_f32m8(m, vfmul_vf_f32m8(vout0, relu_coeff, vl), vout0, vl);
                 }
-                vse32_v_f32m2(outptr + out_j, vout0, vl);
+                vse32_v_f32m8(outptr + out_j, vout0, vl);
             }
         else //stride_w == 2 && dilation_w == 1
             for( ; out_j < outW1; out_j += vl, avl -= vl)
             {
                 vl = vsetvl_e32m2(avl);
                 int in_j = out_j * stride_w - pad_l;
-                vfloat32m2_t v00, v01, v02, v10, v11, v12, v20, v21, v22, unused;
-                vfloat32m2_load_deinterleave(imgptr0 + in_j, v00, v01, vl);
-                vfloat32m2_load_deinterleave(imgptr0 + in_j + 2, v02, unused, vl);
-                vfloat32m2_load_deinterleave(imgptr1 + in_j, v10, v11, vl);
-                vfloat32m2_load_deinterleave(imgptr1 + in_j + 2, v12, unused, vl);
-                vfloat32m2_load_deinterleave(imgptr2 + in_j, v20, v21, vl);
-                vfloat32m2_load_deinterleave(imgptr2 + in_j + 2, v22, unused, vl);
-
-                vfloat32m2_t vout0 = vfmul_vf_f32m2(v00, w00, vl);
-                vfloat32m2_t vout1 = vfmul_vf_f32m2(v01, w01, vl);
-                vfloat32m2_t vout2 = vfmul_vf_f32m2(v02, w02, vl);
-                vout0 = vfadd_vf_f32m2(vout0, bias, vl);
-
-                vout0 = vfmacc_vf_f32m2(vout0, w10, v10, vl);
-                vout1 = vfmacc_vf_f32m2(vout1, w11, v11, vl);
-                vout2 = vfmacc_vf_f32m2(vout2, w12, v12, vl);
-
-                vout0 = vfmacc_vf_f32m2(vout0, w20, v20, vl);
-                vout1 = vfmacc_vf_f32m2(vout1, w21, v21, vl);
-                vout2 = vfmacc_vf_f32m2(vout2, w22, v22, vl);
+                vfloat32m2_t vout0 = vfmacc_vf_f32m2(vfmv_v_f_f32m2(bias, vl), w00, vlse32_v_f32m2(imgptr0+in_j  , 8, vl), vl);
+                vfloat32m2_t vout1 = vfmul_vf_f32m2(vlse32_v_f32m2(imgptr0+in_j+1, 8, vl), w01, vl);
+                vfloat32m2_t vout2 = vfmul_vf_f32m2(vlse32_v_f32m2(imgptr0+in_j+2, 8, vl), w02, vl);
+
+                vout0 = vfmacc_vf_f32m2(vout0, w10, vlse32_v_f32m2(imgptr1+in_j  , 8, vl), vl);
+                vout1 = vfmacc_vf_f32m2(vout1, w11, vlse32_v_f32m2(imgptr1+in_j+1, 8, vl), vl);
+                vout2 = vfmacc_vf_f32m2(vout2, w12, vlse32_v_f32m2(imgptr1+in_j+2, 8, vl), vl);
+
+                vout0 = vfmacc_vf_f32m2(vout0, w20, vlse32_v_f32m2(imgptr2+in_j  , 8, vl), vl);
+                vout1 = vfmacc_vf_f32m2(vout1, w21, vlse32_v_f32m2(imgptr2+in_j+1, 8, vl), vl);
+                vout2 = vfmacc_vf_f32m2(vout2, w22, vlse32_v_f32m2(imgptr2+in_j+2, 8, vl), vl);
 
                 vout0 = vfadd_vv_f32m2(vfadd_vv_f32m2(vout0, vout1, vl), vout2, vl);
                 if (relu)
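
Note on the two changes above: the stride_w == 1 path now runs at LMUL = 8
(vsetvl_e32m8 instead of vsetvl_e32m2), so each loop iteration covers four
times as many output columns. Only four m8 register groups exist (32 vector
registers / 8), which is presumably why the nine products are folded into a
single vfmacc chain seeded with the bias via vfmv_v_f_f32m8 rather than kept
in three parallel accumulators; the relu mask type widens accordingly from
vbool16_t to vbool4_t (SEW/LMUL = 32/8 = 4). The stride_w == 2 path drops the
vcompress-based vfloat32m2_load_deinterleave helper entirely: a strided load
(vlse32) with a byte stride of 8, i.e. two floats, reads every second element
directly, so each deinterleaved vector costs one load and no scratch-buffer
round trip.

A minimal standalone sketch of the strided-load trick, reusing the 16-element
example from the removed helper's comment. This is an illustration, not part
of the patch; it assumes a toolchain with the same pre-__riscv_ RVV intrinsic
naming that this file uses (e.g. GCC built with -march=rv64gcv):

    // deinterleave_demo.cpp - hypothetical demo, not OpenCV code
    #include <riscv_vector.h>
    #include <cstdio>

    int main()
    {
        float ptr[16];
        for (int i = 0; i < 16; i++)
            ptr[i] = (float)(i + 1);               // {1, 2, ..., 16}

        int vl = vsetvl_e32m2(8);                  // request 8 lanes
        // Byte stride 8 = 2 floats, so lane k reads ptr[2*k]; offsetting
        // the base pointer by one float yields the odd elements. No
        // vcompress, no store/load round trip through a buffer.
        vfloat32m2_t a = vlse32_v_f32m2(ptr,     8, vl);  // {1, 3, ..., 15}
        vfloat32m2_t b = vlse32_v_f32m2(ptr + 1, 8, vl);  // {2, 4, ..., 16}

        float out[16];
        vse32_v_f32m2(out,      a, vl);
        vse32_v_f32m2(out + vl, b, vl);
        for (int i = 0; i < 2 * vl; i++)
            printf("%g ", out[i]);                 // 1 3 5 ... 15 2 4 ... 16
        printf("\n");
        return 0;
    }

The patch applies the same idea with three tap positions per row: the
vlse32_v_f32m2 calls with bases imgptrN+in_j, imgptrN+in_j+1, and
imgptrN+in_j+2 yield the three kernel taps for the stride-2, dilation-1 case.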