Mirror of https://github.com/opencv/opencv.git
Merge pull request #10468 from fenrus75:avx512-2
* Add a 512-bit codepath to the AVX512 fastConv function. This patch adds a 512-wide codepath to the fastConv() function for AVX512 use. The basic idea is to process the first N * 16 elements of the vector with AVX512, and then run the rest of the vector through the traditional AVX2 codepath.
* dnn: use unaligned AVX512 loads (OpenCV aligns data on a 32-byte boundary)
* dnn: change the "vecsize" condition for AVX512
* dnn: fix indentation
parent f06c44f1f1
commit a75840d19c
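The pattern the commit message describes, stripped down to a single-accumulator dot product, looks roughly like the sketch below. This is a hypothetical illustration, not the OpenCV code: the function name dot_hybrid, the single accumulator, and the AVX2/FMA baseline are all introduced here, and the __AVX512DQ__ guard stands in for OpenCV's CV_AVX512_SKX check. It assumes n is a multiple of 8.

// Minimal, hypothetical sketch (not the OpenCV code) of the hybrid scheme:
// an AVX-512 loop consumes the first N*16 floats, its 512-bit accumulator
// is folded into a 256-bit one, and an AVX2/FMA loop finishes the tail.
#include <immintrin.h>

float dot_hybrid(const float* a, const float* b, int n)
{
    int k = 0;                     // shared index, as in the patch
    __m256 acc = _mm256_setzero_ps();

#if defined(__AVX512F__) && defined(__AVX512DQ__)
    if (n >= 32)                   // same threshold the patch uses
    {
        __m512 acc5 = _mm512_setzero_ps();
        for (; k <= n - 16; k += 16)
        {
            // unaligned loads: the data is only guaranteed 32-byte aligned
            acc5 = _mm512_fmadd_ps(_mm512_loadu_ps(a + k),
                                   _mm512_loadu_ps(b + k), acc5);
        }
        // fold the 512-bit partial sums into the 256-bit accumulator so
        // the AVX2 tail loop can simply keep adding to it
        acc = _mm256_add_ps(_mm512_extractf32x8_ps(acc5, 0),
                            _mm512_extractf32x8_ps(acc5, 1));
    }
#endif

    for (; k <= n - 8; k += 8)     // AVX2 codepath handles the rest
        acc = _mm256_fmadd_ps(_mm256_loadu_ps(a + k),
                              _mm256_loadu_ps(b + k), acc);

    // horizontal sum of the eight remaining lanes
    __m128 s = _mm_add_ps(_mm256_castps256_ps128(acc),
                          _mm256_extractf128_ps(acc, 1));
    s = _mm_hadd_ps(s, s);
    s = _mm_hadd_ps(s, s);
    return _mm_cvtss_f32(s);
}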
@@ -112,6 +112,7 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
         int j = 0;
         for( ; j <= blockSize - 4; j += 4 )
         {
+            int k = 0;
             const float* rptr = rowbuf + j*vecsize_aligned;
 
             __m256 vs00 = _mm256_setzero_ps(), vs01 = _mm256_setzero_ps(),
@@ -121,7 +122,65 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
                    vs20 = _mm256_setzero_ps(), vs21 = _mm256_setzero_ps(),
                    vs22 = _mm256_setzero_ps(), vs23 = _mm256_setzero_ps();
 
-            for( int k = 0; k < vecsize; k += 8, rptr += 8 )
+#if CV_AVX512_SKX // AVX512VL is necessary to avoid register spilling
+            if (vecsize >= 32)
+            {
+                __m512 vs00_5 = _mm512_setzero_ps(), vs01_5 = _mm512_setzero_ps(),
+                       vs02_5 = _mm512_setzero_ps(), vs03_5 = _mm512_setzero_ps(),
+                       vs10_5 = _mm512_setzero_ps(), vs11_5 = _mm512_setzero_ps(),
+                       vs12_5 = _mm512_setzero_ps(), vs13_5 = _mm512_setzero_ps(),
+                       vs20_5 = _mm512_setzero_ps(), vs21_5 = _mm512_setzero_ps(),
+                       vs22_5 = _mm512_setzero_ps(), vs23_5 = _mm512_setzero_ps();
+
+                for (; k <= vecsize - 16; k += 16, rptr += 16)
+                {
+                    __m512 w0 = _mm512_loadu_ps(wptr0 + k);
+                    __m512 w1 = _mm512_loadu_ps(wptr1 + k);
+                    __m512 w2 = _mm512_loadu_ps(wptr2 + k);
+                    __m512 r0 = _mm512_loadu_ps(rptr);
+
+                    vs00_5 = _mm512_fmadd_ps(w0, r0, vs00_5);
+                    vs10_5 = _mm512_fmadd_ps(w1, r0, vs10_5);
+                    vs20_5 = _mm512_fmadd_ps(w2, r0, vs20_5);
+
+                    r0 = _mm512_loadu_ps(rptr + vecsize_aligned);
+                    vs01_5 = _mm512_fmadd_ps(w0, r0, vs01_5);
+                    vs11_5 = _mm512_fmadd_ps(w1, r0, vs11_5);
+                    vs21_5 = _mm512_fmadd_ps(w2, r0, vs21_5);
+
+                    r0 = _mm512_loadu_ps(rptr + vecsize_aligned*2);
+                    vs02_5 = _mm512_fmadd_ps(w0, r0, vs02_5);
+                    vs12_5 = _mm512_fmadd_ps(w1, r0, vs12_5);
+                    vs22_5 = _mm512_fmadd_ps(w2, r0, vs22_5);
+
+                    r0 = _mm512_loadu_ps(rptr + vecsize_aligned*3);
+                    vs03_5 = _mm512_fmadd_ps(w0, r0, vs03_5);
+                    vs13_5 = _mm512_fmadd_ps(w1, r0, vs13_5);
+                    vs23_5 = _mm512_fmadd_ps(w2, r0, vs23_5);
+                }
+                /*
+                 * now fold the 512 bit accumulator vectors into 256 bit vectors so that the AVX2 code can finish
+                 * the tail of the vector
+                 */
+                vs00 = _mm256_add_ps( _mm512_extractf32x8_ps(vs00_5, 0), _mm512_extractf32x8_ps(vs00_5, 1));
+                vs10 = _mm256_add_ps( _mm512_extractf32x8_ps(vs10_5, 0), _mm512_extractf32x8_ps(vs10_5, 1));
+                vs20 = _mm256_add_ps( _mm512_extractf32x8_ps(vs20_5, 0), _mm512_extractf32x8_ps(vs20_5, 1));
+
+                vs01 = _mm256_add_ps( _mm512_extractf32x8_ps(vs01_5, 0), _mm512_extractf32x8_ps(vs01_5, 1));
+                vs11 = _mm256_add_ps( _mm512_extractf32x8_ps(vs11_5, 0), _mm512_extractf32x8_ps(vs11_5, 1));
+                vs21 = _mm256_add_ps( _mm512_extractf32x8_ps(vs21_5, 0), _mm512_extractf32x8_ps(vs21_5, 1));
+
+                vs02 = _mm256_add_ps( _mm512_extractf32x8_ps(vs02_5, 0), _mm512_extractf32x8_ps(vs02_5, 1));
+                vs12 = _mm256_add_ps( _mm512_extractf32x8_ps(vs12_5, 0), _mm512_extractf32x8_ps(vs12_5, 1));
+                vs22 = _mm256_add_ps( _mm512_extractf32x8_ps(vs22_5, 0), _mm512_extractf32x8_ps(vs22_5, 1));
+
+                vs03 = _mm256_add_ps( _mm512_extractf32x8_ps(vs03_5, 0), _mm512_extractf32x8_ps(vs03_5, 1));
+                vs13 = _mm256_add_ps( _mm512_extractf32x8_ps(vs13_5, 0), _mm512_extractf32x8_ps(vs13_5, 1));
+                vs23 = _mm256_add_ps( _mm512_extractf32x8_ps(vs23_5, 0), _mm512_extractf32x8_ps(vs23_5, 1));
+            }
+#endif
+
+            for (; k < vecsize; k += 8, rptr += 8 )
             {
                 __m256 w0 = _mm256_load_ps(wptr0 + k);
                 __m256 w1 = _mm256_load_ps(wptr1 + k);
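As a quick sanity check that the fold step loses nothing (each of the sixteen lanes of a 512-bit accumulator lands in exactly one of the eight 256-bit lanes, so no partial product is dropped or double-counted), a small harness can compare the dot_hybrid sketch above against a scalar reference. Everything here is hypothetical and illustrative, including the size list and the tolerance.

#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <vector>

// Compare dot_hybrid() (sketch above) against a plain scalar loop on sizes
// that exercise both the 512-bit path (n >= 32) and the pure AVX2 path
// (n < 32). The tolerance is loose because FMA and lane-wise reassociation
// change rounding relative to the scalar summation order.
int main()
{
    for (int n : {8, 24, 32, 40, 128})
    {
        std::vector<float> a(n), b(n);
        for (int i = 0; i < n; i++)
        {
            a[i] = (float)std::rand() / RAND_MAX - 0.5f;
            b[i] = (float)std::rand() / RAND_MAX - 0.5f;
        }
        float ref = 0.f;
        for (int i = 0; i < n; i++)
            ref += a[i] * b[i];
        float got = dot_hybrid(a.data(), b.data(), n);
        std::printf("n=%3d ref=%+.6f got=%+.6f %s\n", n, ref, got,
                    std::fabs(ref - got) < 1e-4f ? "ok" : "MISMATCH");
    }
    return 0;
}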