From e5fb50476c14d0f096ae2b93b2b4be578710a43f Mon Sep 17 00:00:00 2001 From: HAN Liutong Date: Tue, 5 Oct 2021 23:35:00 +0800 Subject: [PATCH] Merge pull request #20521 from hanliutong:dev-rvv-multiVLEN Make the implementation of optimization in DNN adjustable to different vector sizes with RVV intrinsics. * Update fastGEMM for multi VLEN. * Update fastGEMM1T for multi VLEN. * Update fastDepthwiseConv for multi VLEN. * Update fastConv for multi VLEN. * Replace malloc with cv::AutoBuffer. --- modules/dnn/src/layers/layers_common.simd.hpp | 604 ++++++++++-------- 1 file changed, 335 insertions(+), 269 deletions(-) diff --git a/modules/dnn/src/layers/layers_common.simd.hpp b/modules/dnn/src/layers/layers_common.simd.hpp index 762e22e54d..0a077e4631 100644 --- a/modules/dnn/src/layers/layers_common.simd.hpp +++ b/modules/dnn/src/layers/layers_common.simd.hpp @@ -744,58 +744,66 @@ void fastGEMM( const float* aptr, size_t astep, const float* bptr, int ma, int na, int nb ) { int n = 0; - size_t vl = 8; - size_t mvl0 = 8; - size_t mvl1 = 8; - for( ; n < nb; n += 16 ) + int vl = vsetvlmax_e32m4(); + int mvl = vl; + for( ; n < nb; n += vl ) { - if ( n + 16 > nb) { - mvl0 = nb - n; - mvl1 = (nb - n -8) > 0 ? (nb - n -8) : 0; + if ( n + vl > nb) { + mvl = nb - n; } - for( int m = 0; m < ma; m += 4 ) + for( int m = 0; m < ma; m += 7 ) { const float* aptr0 = aptr + astep*m; const float* aptr1 = aptr + astep*std::min(m+1, ma-1); const float* aptr2 = aptr + astep*std::min(m+2, ma-1); const float* aptr3 = aptr + astep*std::min(m+3, ma-1); + const float* aptr4 = aptr + astep*std::min(m+4, ma-1); + const float* aptr5 = aptr + astep*std::min(m+5, ma-1); + const float* aptr6 = aptr + astep*std::min(m+6, ma-1); float* cptr0 = cptr + cstep*m; float* cptr1 = cptr + cstep*std::min(m+1, ma-1); float* cptr2 = cptr + cstep*std::min(m+2, ma-1); float* cptr3 = cptr + cstep*std::min(m+3, ma-1); + float* cptr4 = cptr + cstep*std::min(m+4, ma-1); + float* cptr5 = cptr + cstep*std::min(m+5, ma-1); + float* cptr6 = cptr + cstep*std::min(m+6, ma-1); - vfloat32m2_t d00 = vfmv_v_f_f32m2(0, vl), d01 = vfmv_v_f_f32m2(0, vl); - vfloat32m2_t d10 = vfmv_v_f_f32m2(0, vl), d11 = vfmv_v_f_f32m2(0, vl); - vfloat32m2_t d20 = vfmv_v_f_f32m2(0, vl), d21 = vfmv_v_f_f32m2(0, vl); - vfloat32m2_t d30 = vfmv_v_f_f32m2(0, vl), d31 = vfmv_v_f_f32m2(0, vl); + vfloat32m4_t d0 = vfmv_v_f_f32m4(0, vl); + vfloat32m4_t d1 = vfmv_v_f_f32m4(0, vl); + vfloat32m4_t d2 = vfmv_v_f_f32m4(0, vl); + vfloat32m4_t d3 = vfmv_v_f_f32m4(0, vl); + vfloat32m4_t d4 = vfmv_v_f_f32m4(0, vl); + vfloat32m4_t d5 = vfmv_v_f_f32m4(0, vl); + vfloat32m4_t d6 = vfmv_v_f_f32m4(0, vl); for( int k = 0; k < na; k++ ) { - vfloat32m2_t a0 = vfmv_v_f_f32m2(aptr0[k], vl); - vfloat32m2_t a1 = vfmv_v_f_f32m2(aptr1[k], vl); - vfloat32m2_t a2 = vfmv_v_f_f32m2(aptr2[k], vl); - vfloat32m2_t a3 = vfmv_v_f_f32m2(aptr3[k], vl); - vfloat32m2_t b0 = vle32_v_f32m2(bptr + k*bstep + n, mvl0); - vfloat32m2_t b1 = vle32_v_f32m2(bptr + k*bstep + n + 8, mvl1); - d00 = vfmacc_vv_f32m2(d00, a0, b0, mvl0); - d01 = vfmacc_vv_f32m2(d01, a0, b1, mvl1); - d10 = vfmacc_vv_f32m2(d10, a1, b0, mvl0); - d11 = vfmacc_vv_f32m2(d11, a1, b1, mvl1); - d20 = vfmacc_vv_f32m2(d20, a2, b0, mvl0); - d21 = vfmacc_vv_f32m2(d21, a2, b1, mvl1); - d30 = vfmacc_vv_f32m2(d30, a3, b0, mvl0); - d31 = vfmacc_vv_f32m2(d31, a3, b1, mvl1); + float32_t a0 = aptr0[k]; + float32_t a1 = aptr1[k]; + float32_t a2 = aptr2[k]; + float32_t a3 = aptr3[k]; + float32_t a4 = aptr4[k]; + float32_t a5 = aptr5[k]; + float32_t a6 = aptr6[k]; + + vfloat32m4_t b = vle32_v_f32m4(bptr + k*bstep + n, mvl); + d0 = vfmacc_vf_f32m4(d0, a0, b, mvl); + d1 = vfmacc_vf_f32m4(d1, a1, b, mvl); + d2 = vfmacc_vf_f32m4(d2, a2, b, mvl); + d3 = vfmacc_vf_f32m4(d3, a3, b, mvl); + d4 = vfmacc_vf_f32m4(d4, a4, b, mvl); + d5 = vfmacc_vf_f32m4(d5, a5, b, mvl); + d6 = vfmacc_vf_f32m4(d6, a6, b, mvl); } - vse32_v_f32m2(cptr0 + n, d00, mvl0); - vse32_v_f32m2(cptr1 + n, d10, mvl0); - vse32_v_f32m2(cptr2 + n, d20, mvl0); - vse32_v_f32m2(cptr3 + n, d30, mvl0); - vse32_v_f32m2(cptr0 + n + 8, d01, mvl1); - vse32_v_f32m2(cptr1 + n + 8, d11, mvl1); - vse32_v_f32m2(cptr2 + n + 8, d21, mvl1); - vse32_v_f32m2(cptr3 + n + 8, d31, mvl1); + vse32_v_f32m4(cptr0 + n, d0, mvl); + vse32_v_f32m4(cptr1 + n, d1, mvl); + vse32_v_f32m4(cptr2 + n, d2, mvl); + vse32_v_f32m4(cptr3 + n, d3, mvl); + vse32_v_f32m4(cptr4 + n, d4, mvl); + vse32_v_f32m4(cptr5 + n, d5, mvl); + vse32_v_f32m4(cptr6 + n, d6, mvl); } } } @@ -804,71 +812,108 @@ void fastGEMM1T( const float* vec, const float* weights, size_t wstep, const float* bias, float* dst, int nvecs, int vecsize ) { + int vlm2 = vsetvlmax_e32m2(); int i = 0; - size_t vl = 8; - for( ; i <= nvecs - 8; i += 8 ) + for( ; i <= nvecs - 15; i += 15 ) { const float* wptr = weights + i*wstep; - vfloat32m2_t vs0 = vfmv_v_f_f32m2(0, vl), vs1 = vfmv_v_f_f32m2(0, vl), - vs2 = vfmv_v_f_f32m2(0, vl), vs3 = vfmv_v_f_f32m2(0, vl), - vs4 = vfmv_v_f_f32m2(0, vl), vs5 = vfmv_v_f_f32m2(0, vl), - vs6 = vfmv_v_f_f32m2(0, vl), vs7 = vfmv_v_f_f32m2(0, vl); - - for( int k = 0; k < vecsize; k += 8, wptr += 8 ) + vfloat32m2_t + vs0 = vfmv_v_f_f32m2(0, vlm2), vs1 = vfmv_v_f_f32m2(0, vlm2), vs2 = vfmv_v_f_f32m2(0, vlm2), + vs3 = vfmv_v_f_f32m2(0, vlm2), vs4 = vfmv_v_f_f32m2(0, vlm2), vs5 = vfmv_v_f_f32m2(0, vlm2), + vs6 = vfmv_v_f_f32m2(0, vlm2), vs7 = vfmv_v_f_f32m2(0, vlm2), vs8 = vfmv_v_f_f32m2(0, vlm2), + vs9 = vfmv_v_f_f32m2(0, vlm2), vs10 = vfmv_v_f_f32m2(0, vlm2), vs11 = vfmv_v_f_f32m2(0, vlm2), + vs12 = vfmv_v_f_f32m2(0, vlm2), vs13 = vfmv_v_f_f32m2(0, vlm2), vs14 = vfmv_v_f_f32m2(0, vlm2); + int k = 0; + for( ; k < vecsize - vlm2; k += vlm2, wptr += vlm2 ) { - vfloat32m2_t v = vle32_v_f32m2(vec + k, vl); + vfloat32m2_t v = vle32_v_f32m2(vec + k, vlm2); - vs0 = vfmacc_vv_f32m2(vs0, vle32_v_f32m2(wptr, vl), v, vl); - vs1 = vfmacc_vv_f32m2(vs1, vle32_v_f32m2(wptr + wstep, vl), v, vl); - vs2 = vfmacc_vv_f32m2(vs2, vle32_v_f32m2(wptr + wstep*2, vl), v, vl); - vs3 = vfmacc_vv_f32m2(vs3, vle32_v_f32m2(wptr + wstep*3, vl), v, vl); - vs4 = vfmacc_vv_f32m2(vs4, vle32_v_f32m2(wptr + wstep*4, vl), v, vl); - vs5 = vfmacc_vv_f32m2(vs5, vle32_v_f32m2(wptr + wstep*5, vl), v, vl); - vs6 = vfmacc_vv_f32m2(vs6, vle32_v_f32m2(wptr + wstep*6, vl), v, vl); - vs7 = vfmacc_vv_f32m2(vs7, vle32_v_f32m2(wptr + wstep*7, vl), v, vl); + vs0 = vfmacc_vv_f32m2(vs0, vle32_v_f32m2(wptr, vlm2), v, vlm2); + vs1 = vfmacc_vv_f32m2(vs1, vle32_v_f32m2(wptr + wstep, vlm2), v, vlm2); + vs2 = vfmacc_vv_f32m2(vs2, vle32_v_f32m2(wptr + wstep*2, vlm2), v, vlm2); + vs3 = vfmacc_vv_f32m2(vs3, vle32_v_f32m2(wptr + wstep*3, vlm2), v, vlm2); + vs4 = vfmacc_vv_f32m2(vs4, vle32_v_f32m2(wptr + wstep*4, vlm2), v, vlm2); + vs5 = vfmacc_vv_f32m2(vs5, vle32_v_f32m2(wptr + wstep*5, vlm2), v, vlm2); + vs6 = vfmacc_vv_f32m2(vs6, vle32_v_f32m2(wptr + wstep*6, vlm2), v, vlm2); + vs7 = vfmacc_vv_f32m2(vs7, vle32_v_f32m2(wptr + wstep*7, vlm2), v, vlm2); + vs8 = vfmacc_vv_f32m2(vs8, vle32_v_f32m2(wptr + wstep*8, vlm2), v, vlm2); + vs9 = vfmacc_vv_f32m2(vs9, vle32_v_f32m2(wptr + wstep*9, vlm2), v, vlm2); + vs10 = vfmacc_vv_f32m2(vs10, vle32_v_f32m2(wptr + wstep*10, vlm2), v, vlm2); + vs11 = vfmacc_vv_f32m2(vs11, vle32_v_f32m2(wptr + wstep*11, vlm2), v, vlm2); + vs12 = vfmacc_vv_f32m2(vs12, vle32_v_f32m2(wptr + wstep*12, vlm2), v, vlm2); + vs13 = vfmacc_vv_f32m2(vs13, vle32_v_f32m2(wptr + wstep*13, vlm2), v, vlm2); + vs14 = vfmacc_vv_f32m2(vs14, vle32_v_f32m2(wptr + wstep*14, vlm2), v, vlm2); + } + int kvl = vecsize - k; + if (kvl > 0) { + vfloat32m2_t v = vle32_v_f32m2(vec + k, kvl); + vs0 = vfmacc_vv_f32m2(vs0, vle32_v_f32m2(wptr, kvl), v, kvl); + vs1 = vfmacc_vv_f32m2(vs1, vle32_v_f32m2(wptr + wstep*1, kvl), v, kvl); + vs2 = vfmacc_vv_f32m2(vs2, vle32_v_f32m2(wptr + wstep*2, kvl), v, kvl); + vs3 = vfmacc_vv_f32m2(vs3, vle32_v_f32m2(wptr + wstep*3, kvl), v, kvl); + vs4 = vfmacc_vv_f32m2(vs4, vle32_v_f32m2(wptr + wstep*4, kvl), v, kvl); + vs5 = vfmacc_vv_f32m2(vs5, vle32_v_f32m2(wptr + wstep*5, kvl), v, kvl); + vs6 = vfmacc_vv_f32m2(vs6, vle32_v_f32m2(wptr + wstep*6, kvl), v, kvl); + vs7 = vfmacc_vv_f32m2(vs7, vle32_v_f32m2(wptr + wstep*7, kvl), v, kvl); + vs8 = vfmacc_vv_f32m2(vs8, vle32_v_f32m2(wptr + wstep*8, kvl), v, kvl); + vs9 = vfmacc_vv_f32m2(vs9, vle32_v_f32m2(wptr + wstep*9, kvl), v, kvl); + vs10 = vfmacc_vv_f32m2(vs10, vle32_v_f32m2(wptr + wstep*10, kvl), v, kvl); + vs11 = vfmacc_vv_f32m2(vs11, vle32_v_f32m2(wptr + wstep*11, kvl), v, kvl); + vs12 = vfmacc_vv_f32m2(vs12, vle32_v_f32m2(wptr + wstep*12, kvl), v, kvl); + vs13 = vfmacc_vv_f32m2(vs13, vle32_v_f32m2(wptr + wstep*13, kvl), v, kvl); + vs14 = vfmacc_vv_f32m2(vs14, vle32_v_f32m2(wptr + wstep*14, kvl), v, kvl); } // Calculate the sum of each vector - vfloat32m1_t zero = vfmv_v_f_f32m1(0, vl); - vfloat32m1_t temp0 = vfredsum_vs_f32m2_f32m1(temp0, vs0, zero, vl); - vfloat32m1_t temp1 = vfredsum_vs_f32m2_f32m1(temp1, vs1, zero, vl); - vfloat32m1_t temp2 = vfredsum_vs_f32m2_f32m1(temp2, vs2, zero, vl); - vfloat32m1_t temp3 = vfredsum_vs_f32m2_f32m1(temp3, vs3, zero, vl); - vfloat32m1_t temp4 = vfredsum_vs_f32m2_f32m1(temp4, vs4, zero, vl); - vfloat32m1_t temp5 = vfredsum_vs_f32m2_f32m1(temp5, vs5, zero, vl); - vfloat32m1_t temp6 = vfredsum_vs_f32m2_f32m1(temp6, vs6, zero, vl); - vfloat32m1_t temp7 = vfredsum_vs_f32m2_f32m1(temp7, vs7, zero, vl); - float32_t sum[8]; - sum[0] = vfmv_f_s_f32m1_f32(temp0); - sum[1] = vfmv_f_s_f32m1_f32(temp1); - sum[2] = vfmv_f_s_f32m1_f32(temp2); - sum[3] = vfmv_f_s_f32m1_f32(temp3); - sum[4] = vfmv_f_s_f32m1_f32(temp4); - sum[5] = vfmv_f_s_f32m1_f32(temp5); - sum[6] = vfmv_f_s_f32m1_f32(temp6); - sum[7] = vfmv_f_s_f32m1_f32(temp7); - vfloat32m2_t s0 = vfadd_vv_f32m2(vle32_v_f32m2(sum, vl), vle32_v_f32m2(bias + i, vl), vl); - vse32_v_f32m2(dst + i, s0, vl); + float32_t sum[15]; + vfloat32m1_t zero = vfmv_v_f_f32m1(0, vlm2); + sum[0] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs0, zero, vlm2)); + sum[1] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs1, zero, vlm2)); + sum[2] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs2, zero, vlm2)); + sum[3] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs3, zero, vlm2)); + sum[4] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs4, zero, vlm2)); + sum[5] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs5, zero, vlm2)); + sum[6] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs6, zero, vlm2)); + sum[7] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs7, zero, vlm2)); + sum[8] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs8, zero, vlm2)); + sum[9] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs9, zero, vlm2)); + sum[10] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs10, zero, vlm2)); + sum[11] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs11, zero, vlm2)); + sum[12] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs12, zero, vlm2)); + sum[13] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs13, zero, vlm2)); + sum[14] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs14, zero, vlm2)); + + vfloat32m4_t s0 = vfadd_vv_f32m4(vle32_v_f32m4(sum, 15), vle32_v_f32m4(bias + i, 15), 15); + vse32_v_f32m4(dst + i, s0, 15); } int mvl = nvecs - i; if (mvl > 0) { const float* wptr = weights + i*wstep; - vfloat32m2_t vs0 = vfmv_v_f_f32m2(0, vl), vs1 = vfmv_v_f_f32m2(0, vl), - vs2 = vfmv_v_f_f32m2(0, vl), vs3 = vfmv_v_f_f32m2(0, vl), - vs4 = vfmv_v_f_f32m2(0, vl), vs5 = vfmv_v_f_f32m2(0, vl), - vs6 = vfmv_v_f_f32m2(0, vl), vs7 = vfmv_v_f_f32m2(0, vl); + vfloat32m2_t + vs0 = vfmv_v_f_f32m2(0, vlm2), vs1 = vfmv_v_f_f32m2(0, vlm2), vs2 = vfmv_v_f_f32m2(0, vlm2), + vs3 = vfmv_v_f_f32m2(0, vlm2), vs4 = vfmv_v_f_f32m2(0, vlm2), vs5 = vfmv_v_f_f32m2(0, vlm2), + vs6 = vfmv_v_f_f32m2(0, vlm2), vs7 = vfmv_v_f_f32m2(0, vlm2), vs8 = vfmv_v_f_f32m2(0, vlm2), + vs9 = vfmv_v_f_f32m2(0, vlm2), vs10 = vfmv_v_f_f32m2(0, vlm2), vs11 = vfmv_v_f_f32m2(0, vlm2), + vs12 = vfmv_v_f_f32m2(0, vlm2), vs13 = vfmv_v_f_f32m2(0, vlm2); int k = 0; - for( ; k <= vecsize - 8; k += 8, wptr += 8 ) + for( ; k <= vecsize - vlm2; k += vlm2, wptr += vlm2 ) { - vfloat32m2_t v = vle32_v_f32m2(vec + k, vl); - vs0 = vfmacc_vv_f32m2(vs0, vle32_v_f32m2(wptr, vl), v, vl); - vs1 = vfmacc_vv_f32m2(vs1, vle32_v_f32m2(wptr + wstep*std::min(1, mvl-1), vl), v, vl); - vs2 = vfmacc_vv_f32m2(vs2, vle32_v_f32m2(wptr + wstep*std::min(2, mvl-1), vl), v, vl); - vs3 = vfmacc_vv_f32m2(vs3, vle32_v_f32m2(wptr + wstep*std::min(3, mvl-1), vl), v, vl); - vs4 = vfmacc_vv_f32m2(vs4, vle32_v_f32m2(wptr + wstep*std::min(4, mvl-1), vl), v, vl); - vs5 = vfmacc_vv_f32m2(vs5, vle32_v_f32m2(wptr + wstep*std::min(5, mvl-1), vl), v, vl); - vs6 = vfmacc_vv_f32m2(vs6, vle32_v_f32m2(wptr + wstep*std::min(6, mvl-1), vl), v, vl); + vfloat32m2_t v = vle32_v_f32m2(vec + k, vlm2); + vs0 = vfmacc_vv_f32m2(vs0, vle32_v_f32m2(wptr, vlm2), v, vlm2); + vs1 = vfmacc_vv_f32m2(vs1, vle32_v_f32m2(wptr + wstep*std::min(1, mvl-1), vlm2), v, vlm2); + vs2 = vfmacc_vv_f32m2(vs2, vle32_v_f32m2(wptr + wstep*std::min(2, mvl-1), vlm2), v, vlm2); + vs3 = vfmacc_vv_f32m2(vs3, vle32_v_f32m2(wptr + wstep*std::min(3, mvl-1), vlm2), v, vlm2); + vs4 = vfmacc_vv_f32m2(vs4, vle32_v_f32m2(wptr + wstep*std::min(4, mvl-1), vlm2), v, vlm2); + vs5 = vfmacc_vv_f32m2(vs5, vle32_v_f32m2(wptr + wstep*std::min(5, mvl-1), vlm2), v, vlm2); + vs6 = vfmacc_vv_f32m2(vs6, vle32_v_f32m2(wptr + wstep*std::min(6, mvl-1), vlm2), v, vlm2); + vs7 = vfmacc_vv_f32m2(vs7, vle32_v_f32m2(wptr + wstep*std::min(7, mvl-1), vlm2), v, vlm2); + vs8 = vfmacc_vv_f32m2(vs8, vle32_v_f32m2(wptr + wstep*std::min(8, mvl-1), vlm2), v, vlm2); + vs9 = vfmacc_vv_f32m2(vs9, vle32_v_f32m2(wptr + wstep*std::min(9, mvl-1), vlm2), v, vlm2); + vs10 = vfmacc_vv_f32m2(vs10, vle32_v_f32m2(wptr + wstep*std::min(10, mvl-1), vlm2), v, vlm2); + vs11 = vfmacc_vv_f32m2(vs11, vle32_v_f32m2(wptr + wstep*std::min(11, mvl-1), vlm2), v, vlm2); + vs12 = vfmacc_vv_f32m2(vs12, vle32_v_f32m2(wptr + wstep*std::min(12, mvl-1), vlm2), v, vlm2); + vs13 = vfmacc_vv_f32m2(vs13, vle32_v_f32m2(wptr + wstep*std::min(13, mvl-1), vlm2), v, vlm2); } int kvl = vecsize - k; if (kvl > 0) { @@ -880,54 +925,47 @@ void fastGEMM1T( const float* vec, const float* weights, vs4 = vfmacc_vv_f32m2(vs4, vle32_v_f32m2(wptr + wstep*std::min(4, mvl-1), kvl), v, kvl); vs5 = vfmacc_vv_f32m2(vs5, vle32_v_f32m2(wptr + wstep*std::min(5, mvl-1), kvl), v, kvl); vs6 = vfmacc_vv_f32m2(vs6, vle32_v_f32m2(wptr + wstep*std::min(6, mvl-1), kvl), v, kvl); + vs7 = vfmacc_vv_f32m2(vs7, vle32_v_f32m2(wptr + wstep*std::min(7, mvl-1), kvl), v, kvl); + vs8 = vfmacc_vv_f32m2(vs8, vle32_v_f32m2(wptr + wstep*std::min(8, mvl-1), kvl), v, kvl); + vs9 = vfmacc_vv_f32m2(vs9, vle32_v_f32m2(wptr + wstep*std::min(9, mvl-1), kvl), v, kvl); + vs10 = vfmacc_vv_f32m2(vs10, vle32_v_f32m2(wptr + wstep*std::min(10, mvl-1), kvl), v, kvl); + vs11 = vfmacc_vv_f32m2(vs11, vle32_v_f32m2(wptr + wstep*std::min(11, mvl-1), kvl), v, kvl); + vs12 = vfmacc_vv_f32m2(vs12, vle32_v_f32m2(wptr + wstep*std::min(12, mvl-1), kvl), v, kvl); + vs13 = vfmacc_vv_f32m2(vs13, vle32_v_f32m2(wptr + wstep*std::min(13, mvl-1), kvl), v, kvl); } // Calculate the sum of each vector - vfloat32m1_t zero = vfmv_v_f_f32m1(0, vl); - vfloat32m1_t temp0 = vfmv_v_f_f32m1(0, 4), temp1 = vfmv_v_f_f32m1(0, 4), - temp2 = vfmv_v_f_f32m1(0, 4), temp3 = vfmv_v_f_f32m1(0, 4), - temp4 = vfmv_v_f_f32m1(0, 4), temp5 = vfmv_v_f_f32m1(0, 4), - temp6 = vfmv_v_f_f32m1(0, 4), temp7 = vfmv_v_f_f32m1(0, 4); - temp0 = vfredsum_vs_f32m2_f32m1(temp0, vs0, zero, vl); - temp1 = vfredsum_vs_f32m2_f32m1(temp1, vs1, zero, vl); - temp2 = vfredsum_vs_f32m2_f32m1(temp2, vs2, zero, vl); - temp3 = vfredsum_vs_f32m2_f32m1(temp3, vs3, zero, vl); - temp4 = vfredsum_vs_f32m2_f32m1(temp4, vs4, zero, vl); - temp5 = vfredsum_vs_f32m2_f32m1(temp5, vs5, zero, vl); - temp6 = vfredsum_vs_f32m2_f32m1(temp6, vs6, zero, vl); - temp7 = vfredsum_vs_f32m2_f32m1(temp7, vs7, zero, vl); + float32_t sum[14]; + vfloat32m1_t zero = vfmv_v_f_f32m1(0, vlm2); + sum[0] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs0, zero, vlm2)); + sum[1] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs1, zero, vlm2)); + sum[2] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs2, zero, vlm2)); + sum[3] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs3, zero, vlm2)); + sum[4] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs4, zero, vlm2)); + sum[5] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs5, zero, vlm2)); + sum[6] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs6, zero, vlm2)); + sum[7] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs7, zero, vlm2)); + sum[8] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs8, zero, vlm2)); + sum[9] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs9, zero, vlm2)); + sum[10] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs10, zero, vlm2)); + sum[11] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs11, zero, vlm2)); + sum[12] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs12, zero, vlm2)); + sum[13] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m2_f32m1(zero, vs13, zero, vlm2)); - float32_t sum[8]; - sum[0] = vfmv_f_s_f32m1_f32(temp0); - sum[1] = vfmv_f_s_f32m1_f32(temp1); - sum[2] = vfmv_f_s_f32m1_f32(temp2); - sum[3] = vfmv_f_s_f32m1_f32(temp3); - sum[4] = vfmv_f_s_f32m1_f32(temp4); - sum[5] = vfmv_f_s_f32m1_f32(temp5); - sum[6] = vfmv_f_s_f32m1_f32(temp6); - sum[7] = vfmv_f_s_f32m1_f32(temp7); - - vfloat32m2_t s0 = vfadd_vv_f32m2(vle32_v_f32m2(sum, mvl), vle32_v_f32m2(bias + i, mvl), mvl); - vse32_v_f32m2(dst + i, s0, mvl); + vfloat32m4_t s0 = vfadd_vv_f32m4(vle32_v_f32m4(sum, mvl), vle32_v_f32m4(bias + i, mvl), mvl); + vse32_v_f32m4(dst + i, s0, mvl); } } -enum { FASCONV_BASE_VECSZ = 4 }; // TODO: Large base size. +enum { FASCONV_BASE_VECSZ = 8 }; void fastConv( const float* weights, size_t wstep, const float* bias, const float* rowbuf, float* output, const int* outShape, int blockSize, int vecsize, int vecsize_aligned, const float* relu, bool initOutput ) { - int vl = 4; + int vl = FASCONV_BASE_VECSZ; + int vlm1Max = vsetvlmax_e32m1(); int outCn = outShape[1]; size_t outPlaneSize = outShape[2]*outShape[3]; - float r0 = 1.f, r1 = 1.f, r2 = 1.f; - vfloat32m1_t vr0 = vfmv_v_f_f32m1(1, vl), vr1 = vfmv_v_f_f32m1(1, vl), vr2 = vfmv_v_f_f32m1(1, vl); - int maskbuf[FASCONV_BASE_VECSZ] = {0}; - int rsz = blockSize % FASCONV_BASE_VECSZ; - for( int i = 0; i < rsz; i++ ) - maskbuf[FASCONV_BASE_VECSZ - i - 1] = -1; - vint32m1_t vmaskbuf = vle32_v_i32m1(maskbuf ,vl); - vbool32_t mask = vmslt_vx_i32m1_b32(vmaskbuf, 0, vl); // mask for tail // now compute dot product of the weights // and im2row-transformed part of the tensor for( int i = 0; i < outCn; i += 3 ) @@ -953,20 +991,6 @@ void fastConv( const float* weights, size_t wstep, const float* bias, } } - if( relu ) - { - r0 = relu[i]; r1 = relu[i+1]; r2 = relu[i+2]; - if( i+2 >= outCn ) - { - r2 = r1; - if( i+1 >= outCn ) - r2 = r1 = r0; - } - vr0 = vfmv_v_f_f32m1(r0, vl); - vr1 = vfmv_v_f_f32m1(r1, vl); - vr2 = vfmv_v_f_f32m1(r2, vl); - } - int j = 0; for( ; j < blockSize; j += FASCONV_BASE_VECSZ ) { @@ -983,110 +1007,152 @@ void fastConv( const float* weights, size_t wstep, const float* bias, } int k = 0; const float* rptr = rowbuf + j*vecsize_aligned; - int vlm2 = 8; - vfloat32m2_t vs00 = vfmv_v_f_f32m2(0, vlm2), vs01 = vfmv_v_f_f32m2(0, vlm2), - vs02 = vfmv_v_f_f32m2(0, vlm2), vs03 = vfmv_v_f_f32m2(0, vlm2), - vs10 = vfmv_v_f_f32m2(0, vlm2), vs11 = vfmv_v_f_f32m2(0, vlm2), - vs12 = vfmv_v_f_f32m2(0, vlm2), vs13 = vfmv_v_f_f32m2(0, vlm2), - vs20 = vfmv_v_f_f32m2(0, vlm2), vs21 = vfmv_v_f_f32m2(0, vlm2), - vs22 = vfmv_v_f_f32m2(0, vlm2), vs23 = vfmv_v_f_f32m2(0, vlm2); + int vlm1 = vsetvlmax_e32m1(); + vfloat32m1_t + vs00 = vfmv_v_f_f32m1(0, vlm1), vs10 = vfmv_v_f_f32m1(0, vlm1), vs20 = vfmv_v_f_f32m1(0, vlm1), + vs01 = vfmv_v_f_f32m1(0, vlm1), vs11 = vfmv_v_f_f32m1(0, vlm1), vs21 = vfmv_v_f_f32m1(0, vlm1), + vs02 = vfmv_v_f_f32m1(0, vlm1), vs12 = vfmv_v_f_f32m1(0, vlm1), vs22 = vfmv_v_f_f32m1(0, vlm1), + vs03 = vfmv_v_f_f32m1(0, vlm1), vs13 = vfmv_v_f_f32m1(0, vlm1), vs23 = vfmv_v_f_f32m1(0, vlm1), + vs04 = vfmv_v_f_f32m1(0, vlm1), vs14 = vfmv_v_f_f32m1(0, vlm1), vs24 = vfmv_v_f_f32m1(0, vlm1), + vs05 = vfmv_v_f_f32m1(0, vlm1), vs15 = vfmv_v_f_f32m1(0, vlm1), vs25 = vfmv_v_f_f32m1(0, vlm1), + vs06 = vfmv_v_f_f32m1(0, vlm1), vs16 = vfmv_v_f_f32m1(0, vlm1), vs26 = vfmv_v_f_f32m1(0, vlm1), + vs07 = vfmv_v_f_f32m1(0, vlm1), vs17 = vfmv_v_f_f32m1(0, vlm1), vs27 = vfmv_v_f_f32m1(0, vlm1); - for (; k < vecsize; k += 8, rptr += 8 ) + for (; k < vecsize; k += vlm1, rptr += vlm1 ) { - if (k+8 >= vecsize) { - vlm2 = vecsize - k; + if (k + vlm1 >= vecsize) { + vlm1 = vecsize - k; } - vfloat32m2_t w0 = vle32_v_f32m2(wptr0 + k, vlm2); - vfloat32m2_t w1 = vle32_v_f32m2(wptr1 + k, vlm2); - vfloat32m2_t w2 = vle32_v_f32m2(wptr2 + k, vlm2); - vfloat32m2_t r0 = vle32_v_f32m2(rptr, vlm2); + vfloat32m1_t w0 = vle32_v_f32m1(wptr0 + k, vlm1); + vfloat32m1_t w1 = vle32_v_f32m1(wptr1 + k, vlm1); + vfloat32m1_t w2 = vle32_v_f32m1(wptr2 + k, vlm1); + vfloat32m1_t r0 = vle32_v_f32m1(rptr, vlm1); - vs00 = vfmacc_vv_f32m2(vs00, w0, r0, vlm2); - vs10 = vfmacc_vv_f32m2(vs10, w1, r0, vlm2); - vs20 = vfmacc_vv_f32m2(vs20, w2, r0, vlm2); + vs00 = vfmacc_vv_f32m1(vs00, w0, r0, vlm1); + vs10 = vfmacc_vv_f32m1(vs10, w1, r0, vlm1); + vs20 = vfmacc_vv_f32m1(vs20, w2, r0, vlm1); - r0 = vle32_v_f32m2(rptr + vecsize_aligned, vlm2); - vs01 = vfmacc_vv_f32m2(vs01, w0, r0, vlm2); - vs11 = vfmacc_vv_f32m2(vs11, w1, r0, vlm2); - vs21 = vfmacc_vv_f32m2(vs21, w2, r0, vlm2); + r0 = vle32_v_f32m1(rptr + vecsize_aligned, vlm1); + vs01 = vfmacc_vv_f32m1(vs01, w0, r0, vlm1); + vs11 = vfmacc_vv_f32m1(vs11, w1, r0, vlm1); + vs21 = vfmacc_vv_f32m1(vs21, w2, r0, vlm1); - r0 = vle32_v_f32m2(rptr + vecsize_aligned*2, vlm2); - vs02 = vfmacc_vv_f32m2(vs02, w0, r0, vlm2); - vs12 = vfmacc_vv_f32m2(vs12, w1, r0, vlm2); - vs22 = vfmacc_vv_f32m2(vs22, w2, r0, vlm2); + r0 = vle32_v_f32m1(rptr + vecsize_aligned*2, vlm1); + vs02 = vfmacc_vv_f32m1(vs02, w0, r0, vlm1); + vs12 = vfmacc_vv_f32m1(vs12, w1, r0, vlm1); + vs22 = vfmacc_vv_f32m1(vs22, w2, r0, vlm1); - r0 = vle32_v_f32m2(rptr + vecsize_aligned*3, vlm2); - vs03 = vfmacc_vv_f32m2(vs03, w0, r0, vlm2); - vs13 = vfmacc_vv_f32m2(vs13, w1, r0, vlm2); - vs23 = vfmacc_vv_f32m2(vs23, w2, r0, vlm2); + r0 = vle32_v_f32m1(rptr + vecsize_aligned*3, vlm1); + vs03 = vfmacc_vv_f32m1(vs03, w0, r0, vlm1); + vs13 = vfmacc_vv_f32m1(vs13, w1, r0, vlm1); + vs23 = vfmacc_vv_f32m1(vs23, w2, r0, vlm1); + + r0 = vle32_v_f32m1(rptr + vecsize_aligned*4, vlm1); + vs04 = vfmacc_vv_f32m1(vs04, w0, r0, vlm1); + vs14 = vfmacc_vv_f32m1(vs14, w1, r0, vlm1); + vs24 = vfmacc_vv_f32m1(vs24, w2, r0, vlm1); + + r0 = vle32_v_f32m1(rptr + vecsize_aligned*5, vlm1); + vs05 = vfmacc_vv_f32m1(vs05, w0, r0, vlm1); + vs15 = vfmacc_vv_f32m1(vs15, w1, r0, vlm1); + vs25 = vfmacc_vv_f32m1(vs25, w2, r0, vlm1); + + r0 = vle32_v_f32m1(rptr + vecsize_aligned*6, vlm1); + vs06 = vfmacc_vv_f32m1(vs06, w0, r0, vlm1); + vs16 = vfmacc_vv_f32m1(vs16, w1, r0, vlm1); + vs26 = vfmacc_vv_f32m1(vs26, w2, r0, vlm1); + + r0 = vle32_v_f32m1(rptr + vecsize_aligned*7, vlm1); + vs07 = vfmacc_vv_f32m1(vs07, w0, r0, vlm1); + vs17 = vfmacc_vv_f32m1(vs17, w1, r0, vlm1); + vs27 = vfmacc_vv_f32m1(vs27, w2, r0, vlm1); } - vfloat32m1_t s0, s1, s2; + // compute sum of each vs + vfloat32m1_t zero = vfmv_v_f_f32m1(0, vlm1Max); + // vl is required here to be at least FASCONV_BASE_VECSZ, aka 8. + float32_t sum0[FASCONV_BASE_VECSZ], sum1[FASCONV_BASE_VECSZ], sum2[FASCONV_BASE_VECSZ]; + sum0[0] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs00, zero, vlm1Max)); + sum0[1] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs01, zero, vlm1Max)); + sum0[2] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs02, zero, vlm1Max)); + sum0[3] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs03, zero, vlm1Max)); + sum0[4] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs04, zero, vlm1Max)); + sum0[5] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs05, zero, vlm1Max)); + sum0[6] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs06, zero, vlm1Max)); + sum0[7] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs07, zero, vlm1Max)); + sum1[0] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs10, zero, vlm1Max)); + sum1[1] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs11, zero, vlm1Max)); + sum1[2] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs12, zero, vlm1Max)); + sum1[3] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs13, zero, vlm1Max)); + sum1[4] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs14, zero, vlm1Max)); + sum1[5] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs15, zero, vlm1Max)); + sum1[6] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs16, zero, vlm1Max)); + sum1[7] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs17, zero, vlm1Max)); + sum2[0] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs20, zero, vlm1Max)); + sum2[1] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs21, zero, vlm1Max)); + sum2[2] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs22, zero, vlm1Max)); + sum2[3] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs23, zero, vlm1Max)); + sum2[4] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs24, zero, vlm1Max)); + sum2[5] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs25, zero, vlm1Max)); + sum2[6] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs26, zero, vlm1Max)); + sum2[7] = vfmv_f_s_f32m1_f32(vfredsum_vs_f32m1_f32m1(zero, vs27, zero, vlm1Max)); + + // if VLEN = 128, so LMUL = 2 for vl = 8. + // otherwise, VLEN >=256, we only use fist 8 element of the vReg. + vfloat32m2_t s0, s1, s2; if( initOutput ) { - s0 = vfmv_v_f_f32m1(bias0, vl); - s1 = vfmv_v_f_f32m1(bias1, vl); - s2 = vfmv_v_f_f32m1(bias2, vl); + s0 = vfmv_v_f_f32m2(bias0, vl); + s1 = vfmv_v_f_f32m2(bias1, vl); + s2 = vfmv_v_f_f32m2(bias2, vl); } else { - s0 = vle32_v_f32m1(outptr0 + j, vl); - s1 = vle32_v_f32m1(outptr1 + j, vl); - s2 = vle32_v_f32m1(outptr2 + j, vl); + s0 = vle32_v_f32m2(outptr0 + j, vl); + s1 = vle32_v_f32m2(outptr1 + j, vl); + s2 = vle32_v_f32m2(outptr2 + j, vl); } - // compute sum of each vs - vfloat32m1_t zero = vfmv_v_f_f32m1(0, vl); - vfloat32m1_t temp00 = vfredsum_vs_f32m2_f32m1(temp00, vs00, zero, 8); - vfloat32m1_t temp01 = vfredsum_vs_f32m2_f32m1(temp01, vs01, zero, 8); - vfloat32m1_t temp02 = vfredsum_vs_f32m2_f32m1(temp02, vs02, zero, 8); - vfloat32m1_t temp03 = vfredsum_vs_f32m2_f32m1(temp03, vs03, zero, 8); - vfloat32m1_t temp10 = vfredsum_vs_f32m2_f32m1(temp10, vs10, zero, 8); - vfloat32m1_t temp11 = vfredsum_vs_f32m2_f32m1(temp11, vs11, zero, 8); - vfloat32m1_t temp12 = vfredsum_vs_f32m2_f32m1(temp12, vs12, zero, 8); - vfloat32m1_t temp13 = vfredsum_vs_f32m2_f32m1(temp13, vs13, zero, 8); - vfloat32m1_t temp20 = vfredsum_vs_f32m2_f32m1(temp20, vs20, zero, 8); - vfloat32m1_t temp21 = vfredsum_vs_f32m2_f32m1(temp21, vs21, zero, 8); - vfloat32m1_t temp22 = vfredsum_vs_f32m2_f32m1(temp22, vs22, zero, 8); - vfloat32m1_t temp23 = vfredsum_vs_f32m2_f32m1(temp23, vs23, zero, 8); - float32_t sum0[4], sum1[4], sum2[4]; - sum0[0] = vfmv_f_s_f32m1_f32(temp00); - sum0[1] = vfmv_f_s_f32m1_f32(temp01); - sum0[2] = vfmv_f_s_f32m1_f32(temp02); - sum0[3] = vfmv_f_s_f32m1_f32(temp03); - sum1[0] = vfmv_f_s_f32m1_f32(temp10); - sum1[1] = vfmv_f_s_f32m1_f32(temp11); - sum1[2] = vfmv_f_s_f32m1_f32(temp12); - sum1[3] = vfmv_f_s_f32m1_f32(temp13); - sum2[0] = vfmv_f_s_f32m1_f32(temp20); - sum2[1] = vfmv_f_s_f32m1_f32(temp21); - sum2[2] = vfmv_f_s_f32m1_f32(temp22); - sum2[3] = vfmv_f_s_f32m1_f32(temp23); - - s0 = vfadd_vv_f32m1(vle32_v_f32m1(sum0, vl), s0, vl); - s1 = vfadd_vv_f32m1(vle32_v_f32m1(sum1, vl), s1, vl); - s2 = vfadd_vv_f32m1(vle32_v_f32m1(sum2, vl), s2, vl); - + s0 = vfadd_vv_f32m2(vle32_v_f32m2(sum0, vl), s0, vl); + s1 = vfadd_vv_f32m2(vle32_v_f32m2(sum1, vl), s1, vl); + s2 = vfadd_vv_f32m2(vle32_v_f32m2(sum2, vl), s2, vl); if( relu ) { - vbool32_t m0 = vmfgt_vf_f32m1_b32(s0, 0, vl); - vbool32_t m1 = vmfgt_vf_f32m1_b32(s1, 0, vl); - vbool32_t m2 = vmfgt_vf_f32m1_b32(s2, 0, vl); - s0 = vmerge_vvm_f32m1(m0, vfmul_vv_f32m1(s0, vr0, vl), s0, vl); - s1 = vmerge_vvm_f32m1(m1, vfmul_vv_f32m1(s1, vr1, vl), s1, vl); - s2 = vmerge_vvm_f32m1(m2, vfmul_vv_f32m1(s2, vr2, vl), s2, vl); + vfloat32m2_t vr0 = vfmv_v_f_f32m2(1, vl), vr1 = vfmv_v_f_f32m2(1, vl), vr2 = vfmv_v_f_f32m2(1, vl); + float r0 = relu[i], r1 = relu[i+1], r2 = relu[i+2]; + if( i+2 >= outCn ) + { + r2 = r1; + if( i+1 >= outCn ) + r2 = r1 = r0; + } + vr0 = vfmv_v_f_f32m2(r0, vl); + vr1 = vfmv_v_f_f32m2(r1, vl); + vr2 = vfmv_v_f_f32m2(r2, vl); + vbool16_t m0 = vmfgt_vf_f32m2_b16(s0, 0, vl); + vbool16_t m1 = vmfgt_vf_f32m2_b16(s1, 0, vl); + vbool16_t m2 = vmfgt_vf_f32m2_b16(s2, 0, vl); + s0 = vmerge_vvm_f32m2(m0, vfmul_vv_f32m2(s0, vr0, vl), s0, vl); + s1 = vmerge_vvm_f32m2(m1, vfmul_vv_f32m2(s1, vr1, vl), s1, vl); + s2 = vmerge_vvm_f32m2(m2, vfmul_vv_f32m2(s2, vr2, vl), s2, vl); } if( tail ) { - s0 = vmerge_vvm_f32m1(mask, vle32_v_f32m1(outptr0 + j, vl), s0, vl); - s1 = vmerge_vvm_f32m1(mask, vle32_v_f32m1(outptr1 + j, vl), s1, vl); - s2 = vmerge_vvm_f32m1(mask, vle32_v_f32m1(outptr2 + j, vl), s2, vl); + int maskbuf[FASCONV_BASE_VECSZ] = {0}; + int rsz = blockSize % FASCONV_BASE_VECSZ; + for( int i = 0; i < rsz; i++ ) + maskbuf[FASCONV_BASE_VECSZ - i - 1] = -1; + vint32m2_t vmaskbuf = vle32_v_i32m2(maskbuf ,vl); + vbool16_t mask = vmslt_vx_i32m2_b16(vmaskbuf, 0, vl); // mask for tail + s0 = vmerge_vvm_f32m2(mask, vle32_v_f32m2(outptr0 + j, vl), s0, vl); + s1 = vmerge_vvm_f32m2(mask, vle32_v_f32m2(outptr1 + j, vl), s1, vl); + s2 = vmerge_vvm_f32m2(mask, vle32_v_f32m2(outptr2 + j, vl), s2, vl); } - vse32_v_f32m1(outptr0 + j, s0, vl); - vse32_v_f32m1(outptr1 + j, s1, vl); - vse32_v_f32m1(outptr2 + j, s2, vl); + vse32_v_f32m2(outptr0 + j, s0, vl); + vse32_v_f32m2(outptr1 + j, s1, vl); + vse32_v_f32m2(outptr2 + j, s2, vl); } } } @@ -1097,23 +1163,27 @@ Example for load_deinterleave: output: a = {1, 3, 5, 7, 9, 11, 13, 15} output: b = {2, 4, 6, 8,10, 12, 14, 16} */ -static inline void vfloat32m2_load_deinterleave(const float* ptr, vfloat32m2_t& a, vfloat32m2_t& b) +static inline void vfloat32m2_load_deinterleave(const float* ptr, vfloat32m2_t& a, vfloat32m2_t& b, int vl) { - int vl = 8; - uint32_t masks[] = {1,1,1,1,0,0,0,0}; - vuint32m2_t vm = vle32_v_u32m2(masks,vl); - vbool16_t mask01 = vmseq_vx_u32m2_b16 (vm, 0, vl); - vbool16_t mask10 = vmseq_vx_u32m2_b16 (vm, 1, vl); - vfloat32m2_t ta = vle32_v_f32m2(ptr, vl), tb = vle32_v_f32m2(ptr+8, vl); - uint idx[] = {0,2,4,6,1,3,5,7}; - uint idxa[] = {0,0,0,0,0,1,2,3}, idxb[] = {4,5,6,7,0,0,0,0}; - vuint32m2_t vidxa = vle32_v_u32m2(idxa, 8), vidxb = vle32_v_u32m2(idxb, 8); - vuint32m2_t vidx = vle32_v_u32m2(idx, 8); - vfloat32m2_t high = vfmv_v_f_f32m2(0, 8), low = vfmv_v_f_f32m2(0, 8); - high = vrgather_vv_f32m2(ta, vidx, 8); - low = vrgather_vv_f32m2(tb, vidx, 8); - a = vrgather_vv_f32m2_m(mask01, high, low, vidxa, 8); - b = vrgather_vv_f32m2_m(mask10, low, high, vidxb, 8); + vuint64m4_t mask = vmv_v_x_u64m4(1,vl*2); + vuint32m4_t mask_re = vreinterpret_v_u64m4_u32m4(mask); + vbool8_t mask0 = vmseq_vx_u32m4_b8 (mask_re, 1, vl*2); + vbool8_t mask1 = vmseq_vx_u32m4_b8 (mask_re, 0, vl*2); + vfloat32m4_t tempa = vundefined_f32m4(), tempb = vundefined_f32m4(); + vfloat32m4_t vw = vle32_v_f32m4(ptr, vl*2); + tempa = vcompress_vm_f32m4(mask0, tempa, vw, vl*2); + tempb = vcompress_vm_f32m4(mask1, tempb, vw, vl*2); + /* The following instructions have not to be supported by the GNU toolchain. + So we temporarily use store and load instead. + // a = vlmul_trunc_v_f32m4_f32m2(tempa); + // b = vlmul_trunc_v_f32m4_f32m2(tempb); + */ + cv::AutoBuffer cvBuffer(sizeof(float32_t)*vl*2); + float* buffer = (float*)cvBuffer.data(); + vse32_v_f32m4(buffer, tempa, vl); + a = vle32_v_f32m2(buffer, vl); + vse32_v_f32m4(buffer, tempb, vl); + b = vle32_v_f32m2(buffer, vl); } void fastDepthwiseConv( const float* wptr, @@ -1127,7 +1197,7 @@ void fastDepthwiseConv( const float* wptr, float* outptr_, int out_d, int outH, int outW ) { - int vl = 8; + int vl = vsetvlmax_e32m2(); const float w00_ = wptr[0], w01_ = wptr[1], w02_ = wptr[2], w10 = wptr[3], w11 = wptr[4], w12 = wptr[5], w20_ = wptr[6], w21_ = wptr[7], w22_ = wptr[8]; @@ -1166,17 +1236,11 @@ void fastDepthwiseConv( const float* wptr, if (stride_w == 1 || (stride_w == 2 && dilation_w == 1)) { - const int VECSZ = 8; - vfloat32m2_t vw00 = vfmv_v_f_f32m2(w00, vl), vw01 = vfmv_v_f_f32m2(w01, vl), vw02 = vfmv_v_f_f32m2(w02, vl), - vw10 = vfmv_v_f_f32m2(w10, vl), vw11 = vfmv_v_f_f32m2(w11, vl), vw12 = vfmv_v_f_f32m2(w12, vl), - vw20 = vfmv_v_f_f32m2(w20, vl), vw21 = vfmv_v_f_f32m2(w21, vl), vw22 = vfmv_v_f_f32m2(w22, vl); - vfloat32m2_t vbias = vfmv_v_f_f32m2(bias, vl), vrc = vfmv_v_f_f32m2(relu_coeff, vl); - if( stride_w == 1 ) - for( ; out_j < outW1; out_j += VECSZ ) + for( ; out_j < outW1; out_j += vl ) { - if (out_j + VECSZ > outW1 && out_j > pad_l) - out_j = outW1 - VECSZ; + if (out_j + vl > outW1) + vl = outW1 - out_j; int in_j = out_j * stride_w - pad_l; vfloat32m2_t v00 = vle32_v_f32m2(imgptr0 + in_j, vl), v01 = vle32_v_f32m2(imgptr0 + in_j + dilation_w, vl), @@ -1188,57 +1252,59 @@ void fastDepthwiseConv( const float* wptr, v21 = vle32_v_f32m2(imgptr2 + in_j + dilation_w, vl), v22 = vle32_v_f32m2(imgptr2 + in_j + dilation_w*2, vl); - vfloat32m2_t vout0 = vfmacc_vv_f32m2(vbias, v00, vw00, vl); - vfloat32m2_t vout1 = vfmul_vv_f32m2(v01, vw01, vl); - vfloat32m2_t vout2 = vfmul_vv_f32m2(v02, vw02, vl); + vfloat32m2_t vout0 = vfmul_vf_f32m2(v00, w00, vl); + vfloat32m2_t vout1 = vfmul_vf_f32m2(v01, w01, vl); + vfloat32m2_t vout2 = vfmul_vf_f32m2(v02, w02, vl); + vout0 = vfadd_vf_f32m2(vout0, bias, vl); - vout0 = vfmacc_vv_f32m2(vout0, v10, vw10, vl); - vout1 = vfmacc_vv_f32m2(vout1, v11, vw11, vl); - vout2 = vfmacc_vv_f32m2(vout2, v12, vw12, vl); + vout0 = vfmacc_vf_f32m2(vout0, w10, v10, vl); + vout1 = vfmacc_vf_f32m2(vout1, w11, v11, vl); + vout2 = vfmacc_vf_f32m2(vout2, w12, v12, vl); - vout0 = vfmacc_vv_f32m2(vout0, v20, vw20, vl); - vout1 = vfmacc_vv_f32m2(vout1, v21, vw21, vl); - vout2 = vfmacc_vv_f32m2(vout2, v22, vw22, vl); + vout0 = vfmacc_vf_f32m2(vout0, w20, v20, vl); + vout1 = vfmacc_vf_f32m2(vout1, w21, v21, vl); + vout2 = vfmacc_vf_f32m2(vout2, w22, v22, vl); vout0 = vfadd_vv_f32m2(vfadd_vv_f32m2(vout0, vout1, vl), vout2, vl); if (relu) { vbool16_t m = vmfgt_vf_f32m2_b16(vout0, 0, vl); - vout0 = vmerge_vvm_f32m2(m, vfmul_vv_f32m2(vout0, vrc, vl), vout0, vl); + vout0 = vmerge_vvm_f32m2(m, vfmul_vf_f32m2(vout0, relu_coeff, vl), vout0, vl); } vse32_v_f32m2(outptr + out_j, vout0, vl); } - else - for( ; out_j < outW1; out_j += VECSZ ) + else //stride_w == 2 && dilation_w == 1 + for( ; out_j < outW1; out_j += vl ) { - if (out_j + VECSZ > outW1 && out_j > pad_l) - out_j = outW1 - VECSZ; + if (out_j + vl > outW1) + vl = outW1 - out_j; int in_j = out_j * stride_w - pad_l; vfloat32m2_t v00, v01, v02, v10, v11, v12, v20, v21, v22, unused; - vfloat32m2_load_deinterleave(imgptr0 + in_j, v00, v01); - vfloat32m2_load_deinterleave(imgptr0 + in_j + 2, v02, unused); - vfloat32m2_load_deinterleave(imgptr1 + in_j, v10, v11); - vfloat32m2_load_deinterleave(imgptr1 + in_j + 2, v12, unused); - vfloat32m2_load_deinterleave(imgptr2 + in_j, v20, v21); - vfloat32m2_load_deinterleave(imgptr2 + in_j + 2, v22, unused); + vfloat32m2_load_deinterleave(imgptr0 + in_j, v00, v01, vl); + vfloat32m2_load_deinterleave(imgptr0 + in_j + 2, v02, unused, vl); + vfloat32m2_load_deinterleave(imgptr1 + in_j, v10, v11, vl); + vfloat32m2_load_deinterleave(imgptr1 + in_j + 2, v12, unused, vl); + vfloat32m2_load_deinterleave(imgptr2 + in_j, v20, v21, vl); + vfloat32m2_load_deinterleave(imgptr2 + in_j + 2, v22, unused, vl); - vfloat32m2_t vout0 = vfmacc_vv_f32m2(vbias, v00, vw00, vl); - vfloat32m2_t vout1 = vfmul_vv_f32m2(v01, vw01, vl); - vfloat32m2_t vout2 = vfmul_vv_f32m2(v02, vw02, vl); + vfloat32m2_t vout0 = vfmul_vf_f32m2(v00, w00, vl); + vfloat32m2_t vout1 = vfmul_vf_f32m2(v01, w01, vl); + vfloat32m2_t vout2 = vfmul_vf_f32m2(v02, w02, vl); + vout0 = vfadd_vf_f32m2(vout0, bias, vl); - vout0 = vfmacc_vv_f32m2(vout0, v10, vw10, vl); - vout1 = vfmacc_vv_f32m2(vout1, v11, vw11, vl); - vout2 = vfmacc_vv_f32m2(vout2, v12, vw12, vl); + vout0 = vfmacc_vf_f32m2(vout0, w10, v10, vl); + vout1 = vfmacc_vf_f32m2(vout1, w11, v11, vl); + vout2 = vfmacc_vf_f32m2(vout2, w12, v12, vl); - vout0 = vfmacc_vv_f32m2(vout0, v20, vw20, vl); - vout1 = vfmacc_vv_f32m2(vout1, v21, vw21, vl); - vout2 = vfmacc_vv_f32m2(vout2, v22, vw22, vl); + vout0 = vfmacc_vf_f32m2(vout0, w20, v20, vl); + vout1 = vfmacc_vf_f32m2(vout1, w21, v21, vl); + vout2 = vfmacc_vf_f32m2(vout2, w22, v22, vl); vout0 = vfadd_vv_f32m2(vfadd_vv_f32m2(vout0, vout1, vl), vout2, vl); if (relu) { vbool16_t m = vmfgt_vf_f32m2_b16(vout0, 0, vl); - vout0 = vmerge_vvm_f32m2(m, vfmul_vv_f32m2(vout0, vrc, vl), vout0, vl); + vout0 = vmerge_vvm_f32m2(m, vfmul_vf_f32m2(vout0, relu_coeff, vl), vout0, vl); } vse32_v_f32m2(outptr + out_j, vout0, vl); }