Merge pull request #22183 from zihaomu:fastConv_ARMv7_compatible

DNN: ARMv7 compatible fastConv

* support armv7 on fastConv

* remove whitespace.
This commit is contained in:
Zihao Mu 2022-07-07 18:23:08 +08:00 committed by GitHub
parent a80fcacd90
commit 139c443770
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 180 additions and 98 deletions

View File

@@ -9,7 +9,7 @@
#ifndef FAST_CONV_PRAM
#define FAST_CONV_PRAM
#if CV_NEON && __aarch64__ // 32 registers.
#if CV_NEON && CV_NEON_AARCH64 // 32 registers.
#define FAST_CONV_MR 4
#define FAST_CONV_NR 28
enum { FAST_VEC_NLANES=4 };

View File

@@ -158,59 +158,7 @@ void convBlock_NEON(int k, const float *a, const float *b,
float *c, int ldc, const float *bias,
float minval, float maxval, bool ifActiv)
{
#if FAST_CONV_MR == 4 && FAST_CONV_NR == 12
{
float32x4_t c0 = vdupq_n_f32(bias[0]), c1 = c0, c2 = c0;
float32x4_t c3 = vdupq_n_f32(bias[1]), c4 = c3, c5 = c3;
float32x4_t c6 = vdupq_n_f32(bias[2]), c7 = c6, c8 = c6;
float32x4_t c9 = vdupq_n_f32(bias[3]), c10 = c9, c11 = c9;
float32x4_t a0 = vdupq_n_f32(0.0f);
float32x4_t b0 = vdupq_n_f32(0.0f), b1 = vdupq_n_f32(0.0f), b2 = vdupq_n_f32(0.0f);
for (int p = 0; p < k; p++, a += FAST_CONV_MR, b += FAST_CONV_NR)
{
a0 = vld1q_f32(a);
b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4), b2 = vld1q_f32(b + 8);
c0 = vfmaq_laneq_f32(c0, b0, a0, 0);
c1 = vfmaq_laneq_f32(c1, b1, a0, 0);
c2 = vfmaq_laneq_f32(c2, b2, a0, 0);
c3 = vfmaq_laneq_f32(c3, b0, a0, 1);
c4 = vfmaq_laneq_f32(c4, b1, a0, 1);
c5 = vfmaq_laneq_f32(c5, b2, a0, 1);
c6 = vfmaq_laneq_f32(c6, b0, a0, 2);
c7 = vfmaq_laneq_f32(c7, b1, a0, 2);
c8 = vfmaq_laneq_f32(c8, b2, a0, 2);
c9 = vfmaq_laneq_f32(c9, b0, a0, 3);
c10 = vfmaq_laneq_f32(c10, b1, a0, 3);
c11 = vfmaq_laneq_f32(c11, b2, a0, 3);
}
if (ifActiv)
{
b0 = vdupq_n_f32(minval), b1 = vdupq_n_f32(maxval);
c0 = vminq_f32(vmaxq_f32(c0, b0), b1);
c1 = vminq_f32(vmaxq_f32(c1, b0), b1);
c2 = vminq_f32(vmaxq_f32(c2, b0), b1);
c3 = vminq_f32(vmaxq_f32(c3, b0), b1);
c4 = vminq_f32(vmaxq_f32(c4, b0), b1);
c5 = vminq_f32(vmaxq_f32(c5, b0), b1);
c6 = vminq_f32(vmaxq_f32(c6, b0), b1);
c7 = vminq_f32(vmaxq_f32(c7, b0), b1);
c8 = vminq_f32(vmaxq_f32(c8, b0), b1);
c9 = vminq_f32(vmaxq_f32(c9, b0), b1);
c10 = vminq_f32(vmaxq_f32(c10, b0), b1);
c11 = vminq_f32(vmaxq_f32(c11, b0), b1);
}
vst1q_f32(c, c0); vst1q_f32(c+4, c1); vst1q_f32(c+8, c2);
vst1q_f32(c + ldc, c3); vst1q_f32(c + ldc + 4, c4); vst1q_f32(c + ldc + 8, c5);
vst1q_f32(c + ldc*2, c6); vst1q_f32(c + ldc*2 + 4, c7); vst1q_f32(c + ldc*2 + 8, c8);
vst1q_f32(c + ldc*3, c9); vst1q_f32(c + ldc*3 + 4, c10); vst1q_f32(c + ldc*3 + 8, c11);
}
#elif FAST_CONV_MR == 4 && FAST_CONV_NR == 28
#if CV_NEON_AARCH64 && FAST_CONV_MR == 4 && FAST_CONV_NR == 28 // AARCH64
{
float32x4_t c0 = vdupq_n_f32(bias[0]), c1 = c0, c2 = c0, c3 = c0, c4 = c0, c5 = c0, c24 = c0;
float32x4_t c6 = vdupq_n_f32(bias[1]), c7 = c6, c8 = c6, c9 = c6, c10 = c6, c11 = c6, c25 = c6;
@@ -220,7 +168,8 @@ void convBlock_NEON(int k, const float *a, const float *b,
float32x4_t a0 = vdupq_n_f32(0.0f);
float32x4_t b0 = vdupq_n_f32(0.0f), b1 = vdupq_n_f32(0.0f), b2 = vdupq_n_f32(0.0f);
for (int p = 0; p < k; p++, a += FAST_CONV_MR) {
for (int p = 0; p < k; p++, a += FAST_CONV_MR)
{
a0 = vld1q_f32(a);
b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4), b2 = vld1q_f32(b + 8);
b += 12;
@@ -330,11 +279,63 @@ void convBlock_NEON(int k, const float *a, const float *b,
vst1q_f32(c + ldc * 3 + 20, c23);
vst1q_f32(c + ldc * 3 + 24, c27);
}
#elif (!defined(CV_NEON_AARCH64) || !CV_NEON_AARCH64) && FAST_CONV_MR == 4 && FAST_CONV_NR == 12 // ARMv7
{
float32x4_t c0 = vdupq_n_f32(bias[0]), c1 = c0, c2 = c0;
float32x4_t c3 = vdupq_n_f32(bias[1]), c4 = c3, c5 = c3;
float32x4_t c6 = vdupq_n_f32(bias[2]), c7 = c6, c8 = c6;
float32x4_t c9 = vdupq_n_f32(bias[3]), c10 = c9, c11 = c9;
float32x2_t a0 = vdup_n_f32(0.0f), a1 = a0;
float32x4_t b0 = vdupq_n_f32(0.0f), b1 = vdupq_n_f32(0.0f), b2 = vdupq_n_f32(0.0f);
for (int p = 0; p < k; p++, a += FAST_CONV_MR, b += FAST_CONV_NR)
{
a0 = vld1_f32(a), a1 = vld1_f32(a+2);
b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4), b2 = vld1q_f32(b + 8);
c0 = vmlaq_lane_f32(c0, b0, a0, 0);
c1 = vmlaq_lane_f32(c1, b1, a0, 0);
c2 = vmlaq_lane_f32(c2, b2, a0, 0);
c3 = vmlaq_lane_f32(c3, b0, a0, 1);
c4 = vmlaq_lane_f32(c4, b1, a0, 1);
c5 = vmlaq_lane_f32(c5, b2, a0, 1);
c6 = vmlaq_lane_f32(c6, b0, a1, 0);
c7 = vmlaq_lane_f32(c7, b1, a1, 0);
c8 = vmlaq_lane_f32(c8, b2, a1, 0);
c9 = vmlaq_lane_f32(c9 , b0, a1, 1);
c10 = vmlaq_lane_f32(c10, b1, a1, 1);
c11 = vmlaq_lane_f32(c11, b2, a1, 1);
}
if (ifActiv)
{
b0 = vdupq_n_f32(minval), b1 = vdupq_n_f32(maxval);
c0 = vminq_f32(vmaxq_f32(c0, b0), b1);
c1 = vminq_f32(vmaxq_f32(c1, b0), b1);
c2 = vminq_f32(vmaxq_f32(c2, b0), b1);
c3 = vminq_f32(vmaxq_f32(c3, b0), b1);
c4 = vminq_f32(vmaxq_f32(c4, b0), b1);
c5 = vminq_f32(vmaxq_f32(c5, b0), b1);
c6 = vminq_f32(vmaxq_f32(c6, b0), b1);
c7 = vminq_f32(vmaxq_f32(c7, b0), b1);
c8 = vminq_f32(vmaxq_f32(c8, b0), b1);
c9 = vminq_f32(vmaxq_f32(c9, b0), b1);
c10 = vminq_f32(vmaxq_f32(c10, b0), b1);
c11 = vminq_f32(vmaxq_f32(c11, b0), b1);
}
vst1q_f32(c, c0); vst1q_f32(c+4, c1); vst1q_f32(c+8, c2);
vst1q_f32(c + ldc, c3); vst1q_f32(c + ldc + 4, c4); vst1q_f32(c + ldc + 8, c5);
vst1q_f32(c + ldc*2, c6); vst1q_f32(c + ldc*2 + 4, c7); vst1q_f32(c + ldc*2 + 8, c8);
vst1q_f32(c + ldc*3, c9); vst1q_f32(c + ldc*3 + 4, c10); vst1q_f32(c + ldc*3 + 8, c11);
}
#else
#error "unsupported FAST_CONV_MR and/or FAST_CONV_NR in convBlock_NEON."
#endif
}
#endif
} // namespace opt_NEON

View File

@@ -192,7 +192,7 @@ static void winograd_trans_input_F63(float* src, float* dst, int Channle_div4, c
float* input0 = input_buf0 + 4 * tiles * r;
// TODO! support tiles > 12
//#if (ARMV8)
//#if CV_NEON_AARCH64
// for (; ti + 11 < tiles; ti += 12)
// {
// float* out1 = out0 + line_step * ofstab0[ti * 2] + Channle_div4 * ofstab0[ti * 2 + 1] * 4;
@@ -617,7 +617,6 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
float* output_ptr0 = output.ptr<float>() + bn * out_planesize * K;
// Transform Input
//int taskItemLen = C_aligned/4/ntasks;
int C_aligned_div4 = C_aligned/4;
parallel_for_(Range(0, ntasks), [&](const Range& range)
@@ -1093,59 +1092,63 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
#else // ARMv7 16 registers.
// init 16 registers. FMA/load ratio = 32/12
float32x4_t r00 = vdupq_n_f32(0.0f), r01 = r00, r02 = r00, r03 = r00;
float32x4_t r04 = r00, r05 = r00, r06 = r00, r07 = r00;
float32x4_t r08 = r00, r09 = r00, r10 = r00, r11 = r00;
float32x4_t r12 = r00, r13 = r00, r14 = r00, r15 = r00;
float32x2_t q00 = vdup_n_f32(0.0f), q01 = q00, q02 = q00, q03 = q00,
q04 = q00, q05 = q00, q06 = q00, q07 = q00;
float32x4_t r04 = vdupq_n_f32(0.0f), r05 = r04, r06 = r04, r07 = r04;
float32x4_t r08 = r04, r09 = r04, r10 = r04, r11 = r04;
float32x4_t r12 = r04, r13 = r04, r14 = r04, r15 = r04;
for (; nn > 0; nn--)
{
r00 = vld1q_f32(r0), r01 = vld1q_f32(r0+4), r02 = vld1q_f32(r0+8), r03 = vld1q_f32(r0+12);
q00 = vld1_f32(r0), q01 = vld1_f32(r0+2), q02 = vld1_f32(r0+4), q03 = vld1_f32(r0+6);
q04 = vld1_f32(r0+8), q05 = vld1_f32(r0+10), q06 = vld1_f32(r0+12), q07 = vld1_f32(r0+14);
r04 = vld1q_f32(k0), r05 = vld1q_f32(k0+4), r06 = vld1q_f32(k0+8), r07 = vld1q_f32(k0+12);
r0 += 16, k0 += 16;
r08 = vfmaq_laneq_f32(r08, r04, r00, 0);
r09 = vfmaq_laneq_f32(r09, r04, r01, 0);
r10 = vfmaq_laneq_f32(r10, r04, r02, 0);
r11 = vfmaq_laneq_f32(r11, r04, r03, 0);
r08 = vmlaq_lane_f32(r08, r04, q00, 0);
r09 = vmlaq_lane_f32(r09, r04, q02, 0);
r10 = vmlaq_lane_f32(r10, r04, q04, 0);
r11 = vmlaq_lane_f32(r11, r04, q06, 0);
r08 = vfmaq_laneq_f32(r08, r05, r00, 1);
r09 = vfmaq_laneq_f32(r09, r05, r01, 1);
r10 = vfmaq_laneq_f32(r10, r05, r02, 1);
r11 = vfmaq_laneq_f32(r11, r05, r03, 1);
r08 = vmlaq_lane_f32(r08, r05, q00, 1);
r09 = vmlaq_lane_f32(r09, r05, q02, 1);
r10 = vmlaq_lane_f32(r10, r05, q04, 1);
r11 = vmlaq_lane_f32(r11, r05, q06, 1);
r08 = vfmaq_laneq_f32(r08, r06, r00, 2);
r09 = vfmaq_laneq_f32(r09, r06, r01, 2);
r10 = vfmaq_laneq_f32(r10, r06, r02, 2);
r11 = vfmaq_laneq_f32(r11, r06, r03, 2);
r08 = vmlaq_lane_f32(r08, r06, q01, 0);
r09 = vmlaq_lane_f32(r09, r06, q03, 0);
r10 = vmlaq_lane_f32(r10, r06, q05, 0);
r11 = vmlaq_lane_f32(r11, r06, q07, 0);
r08 = vfmaq_laneq_f32(r08, r07, r00, 3);
r09 = vfmaq_laneq_f32(r09, r07, r01, 3);
r10 = vfmaq_laneq_f32(r10, r07, r02, 3);
r11 = vfmaq_laneq_f32(r11, r07, r03, 3);
r08 = vmlaq_lane_f32(r08, r07, q01, 1);
r09 = vmlaq_lane_f32(r09, r07, q03, 1);
r10 = vmlaq_lane_f32(r10, r07, q05, 1);
r11 = vmlaq_lane_f32(r11, r07, q07, 1);
r00 = vld1q_f32(r0), r01 = vld1q_f32(r0+4), r02 = vld1q_f32(r0+8), r03 = vld1q_f32(r0+12);
q00 = vld1_f32(r0), q01 = vld1_f32(r0+2), q02 = vld1_f32(r0+4), q03 = vld1_f32(r0+6);
q04 = vld1_f32(r0+8), q05 = vld1_f32(r0+10), q06 = vld1_f32(r0+12), q07 = vld1_f32(r0+14);
r0 += 16;
r12 = vfmaq_laneq_f32(r12, r04, r00, 0);
r13 = vfmaq_laneq_f32(r13, r04, r01, 0);
r14 = vfmaq_laneq_f32(r14, r04, r02, 0);
r15 = vfmaq_laneq_f32(r15, r04, r03, 0);
r12 = vmlaq_lane_f32(r12, r04, q00, 0);
r13 = vmlaq_lane_f32(r13, r04, q02, 0);
r14 = vmlaq_lane_f32(r14, r04, q04, 0);
r15 = vmlaq_lane_f32(r15, r04, q06, 0);
r12 = vfmaq_laneq_f32(r12, r05, r00, 1);
r13 = vfmaq_laneq_f32(r13, r05, r01, 1);
r14 = vfmaq_laneq_f32(r14, r05, r02, 1);
r15 = vfmaq_laneq_f32(r15, r05, r03, 1);
r12 = vmlaq_lane_f32(r12, r05, q00, 1);
r13 = vmlaq_lane_f32(r13, r05, q02, 1);
r14 = vmlaq_lane_f32(r14, r05, q04, 1);
r15 = vmlaq_lane_f32(r15, r05, q06, 1);
r12 = vfmaq_laneq_f32(r12, r06, r00, 2);
r13 = vfmaq_laneq_f32(r13, r06, r01, 2);
r14 = vfmaq_laneq_f32(r14, r06, r02, 2);
r15 = vfmaq_laneq_f32(r15, r06, r03, 2);
r12 = vmlaq_lane_f32(r12, r06, q01, 0);
r13 = vmlaq_lane_f32(r13, r06, q03, 0);
r14 = vmlaq_lane_f32(r14, r06, q05, 0);
r15 = vmlaq_lane_f32(r15, r06, q07, 0);
r12 = vfmaq_laneq_f32(r12, r07, r00, 3);
r13 = vfmaq_laneq_f32(r13, r07, r01, 3);
r14 = vfmaq_laneq_f32(r14, r07, r02, 3);
r15 = vfmaq_laneq_f32(r15, r07, r03, 3);
r12 = vmlaq_lane_f32(r12, r07, q01, 1);
r13 = vmlaq_lane_f32(r13, r07, q03, 1);
r14 = vmlaq_lane_f32(r14, r07, q05, 1);
r15 = vmlaq_lane_f32(r15, r07, q07, 1);
}
vst1q_f32(output0_tm, r08), vst1q_f32(output0_tm + 4, r09), vst1q_f32(output0_tm + 8, r10), vst1q_f32(output0_tm + 12, r11);
@@ -1162,7 +1165,7 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
const float* r0 = input_tm + ofstab0[ti * 2] * line_step;
const float* k0 = kernel_tm_i;
#if CV_NEON_AARCH64
// init 12 registers. FMA/load ratio = 12/8
float32x4_t r00 = vdupq_n_f32(0.0f), r01 = r00, r02 = r00, r03 = r00;
float32x4_t r08 = r00, r09 = r00, r10 = r00, r11 = r00;
@@ -1194,7 +1197,42 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
r18 = vfmaq_laneq_f32(r18, r11, r02, 3);
r19 = vfmaq_laneq_f32(r19, r11, r03, 3);
}
#else
// init 12 registers. FMA/load ratio = 12/8
float32x2_t q00 = vdup_n_f32(0.0f), q01 = q00, q02 = q00, q03 = q00,
q04 = q00, q05 = q00, q06 = q00, q07 = q00;
float32x4_t r08 = vdupq_n_f32(0.0f), r09 = r08, r10 = r08, r11 = r08;
float32x4_t r16 = r08, r17 = r08, r18 = r08, r19 = r08;
for(; nn > 0; nn--)
{
q00 = vld1_f32(r0), q01 = vld1_f32(r0+2), q02 = vld1_f32(r0+4), q03 = vld1_f32(r0+6);
q04 = vld1_f32(r0+8), q05 = vld1_f32(r0+10), q06 = vld1_f32(r0+12), q07 = vld1_f32(r0+14);
r08 = vld1q_f32(k0), r09 = vld1q_f32(k0+4), r10 = vld1q_f32(k0+8), r11 = vld1q_f32(k0+12);
r0 += 16, k0 += 16;
r16 = vmlaq_lane_f32(r16, r08, q00, 0);
r17 = vmlaq_lane_f32(r17, r08, q02, 0);
r18 = vmlaq_lane_f32(r18, r08, q04, 0);
r19 = vmlaq_lane_f32(r19, r08, q06, 0);
r16 = vmlaq_lane_f32(r16, r09, q00, 1);
r17 = vmlaq_lane_f32(r17, r09, q02, 1);
r18 = vmlaq_lane_f32(r18, r09, q04, 1);
r19 = vmlaq_lane_f32(r19, r09, q06, 1);
r16 = vmlaq_lane_f32(r16, r10, q01, 0);
r17 = vmlaq_lane_f32(r17, r10, q03, 0);
r18 = vmlaq_lane_f32(r18, r10, q05, 0);
r19 = vmlaq_lane_f32(r19, r10, q07, 0);
r16 = vmlaq_lane_f32(r16, r11, q01, 1);
r17 = vmlaq_lane_f32(r17, r11, q03, 1);
r18 = vmlaq_lane_f32(r18, r11, q05, 1);
r19 = vmlaq_lane_f32(r19, r11, q07, 1);
}
#endif
vst1q_f32(output0_tm, r16), vst1q_f32(output0_tm + 4, r17), vst1q_f32(output0_tm + 8, r18), vst1q_f32(output0_tm + 12, r19);
output0_tm += 16;
}
@@ -1205,6 +1243,7 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
const float* r0 = input_tm + ofstab0[ti * 2] * line_step;
const float* k0 = kernel_tm_i;
#if CV_NEON_AARCH64
// init 8 registers. FMA/load ratio = 8/6
float32x4_t r00 = vdupq_n_f32(0.0f), r01 = r00;
float32x4_t r08 = r00, r09 = r00, r10 = r00, r11 = r00;
@@ -1228,7 +1267,31 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
r16 = vfmaq_laneq_f32(r16, r11, r00, 3);
r17 = vfmaq_laneq_f32(r17, r11, r01, 3);
}
#else
// init 8 registers. FMA/load ratio = 8/6
float32x2_t q00 = vdup_n_f32(0.0f), q01 = q00, q02 = q00, q03 = q00;
float32x4_t r08 = vdupq_n_f32(0.0f), r09 = r08, r10 = r08, r11 = r08;
float32x4_t r16 = r08, r17 = r08;
for(; nn > 0; nn--)
{
q00 = vld1_f32(r0), q01 = vld1_f32(r0+2), q02 = vld1_f32(r0+4), q03 = vld1_f32(r0+6);
r08 = vld1q_f32(k0), r09 = vld1q_f32(k0+4), r10 = vld1q_f32(k0+8), r11 = vld1q_f32(k0+12);
r0 += 8, k0 += 16;
r16 = vmlaq_lane_f32(r16, r08, q00, 0);
r17 = vmlaq_lane_f32(r17, r08, q02, 0);
r16 = vmlaq_lane_f32(r16, r09, q00, 1);
r17 = vmlaq_lane_f32(r17, r09, q02, 1);
r16 = vmlaq_lane_f32(r16, r10, q01, 0);
r17 = vmlaq_lane_f32(r17, r10, q03, 0);
r16 = vmlaq_lane_f32(r16, r11, q01, 1);
r17 = vmlaq_lane_f32(r17, r11, q03, 1);
}
#endif
vst1q_f32(output0_tm, r16), vst1q_f32(output0_tm + 4, r17);
output0_tm += 8;
}
@@ -1239,7 +1302,8 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
const float* r0 = input_tm + ofstab0[ti * 2] * line_step;
const float* k0 = kernel_tm_i;
// init 8 registers. FMA/load ratio = 8/6
#if CV_NEON_AARCH64
// init 6 registers. FMA/load ratio = 6/5
float32x4_t r00 = vdupq_n_f32(0.0f);
float32x4_t r08 = r00, r09 = r00, r10 = r00, r11 = r00;
float32x4_t r16 = r00;
@@ -1255,7 +1319,24 @@ int runWinograd63(InputArray _input, OutputArray _output, const Ptr<FastConv2d>&
r16 = vfmaq_laneq_f32(r16, r10, r00, 2);
r16 = vfmaq_laneq_f32(r16, r11, r00, 3);
}
#else
// init 6 registers. FMA/load ratio = 6/5
float32x2_t q00 = vdup_n_f32(0.0f), q01 = q00;
float32x4_t r08 = vdupq_n_f32(0.0f), r09 = r08, r10 = r08, r11 = r08;
float32x4_t r16 = r08;
for(; nn > 0; nn--)
{
q00 = vld1_f32(r0), q01 = vld1_f32(r0+2);
r08 = vld1q_f32(k0), r09 = vld1q_f32(k0+4), r10 = vld1q_f32(k0+8), r11 = vld1q_f32(k0+12);
r0 += 4, k0 += 16;
r16 = vmlaq_lane_f32(r16, r08, q00, 0);
r16 = vmlaq_lane_f32(r16, r09, q00, 1);
r16 = vmlaq_lane_f32(r16, r10, q01, 0);
r16 = vmlaq_lane_f32(r16, r11, q01, 1);
}
#endif
vst1q_f32(output0_tm, r16);
output0_tm += 4;
}