Provide a few AVX512 optimized functions for the DNN module

This patch adds an AVX512-optimized fastConv, as well as the hookups needed to get it called from the convolution layer. The AVX512 fastConv is code-identical at the C level to the AVX2 one, but is measurably faster because AVX512 has more registers available to cache results in.

Signed-off-by: Arjan van de Ven <arjan@linux.intel.com>
commit 2938860b3f (parent fc8e848a54)
modules/dnn/CMakeLists.txt

@@ -13,7 +13,7 @@ endif()
 
 set(the_description "Deep neural network module. It allows to load models from different frameworks and to make forward pass")
 
-ocv_add_dispatched_file("layers/layers_common" AVX AVX2)
+ocv_add_dispatched_file("layers/layers_common" AVX AVX2 AVX512)
 
 ocv_add_module(dnn opencv_core opencv_imgproc WRAP python matlab java js)
 ocv_warnings_disable(CMAKE_CXX_FLAGS -Wno-shadow -Wno-parentheses -Wmaybe-uninitialized -Wsign-promo
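The ocv_add_dispatched_file change is what wires the new code in: OpenCV's build compiles layers/layers_common.simd.hpp once per listed instruction set and exposes each build under its own namespace, which is why the layer code below can call opt_AVX512::fastConv next to opt_AVX2::fastConv. A hand-rolled sketch of that namespace-per-ISA pattern (illustration only; the real namespaces and runtime flags are generated by the build, and opt_BASE here is a placeholder):

#include <cstdio>

// Each "namespace" would really be the same source file compiled with
// different -mavx2 / -mavx512f flags; here they are just stub kernels.
namespace opt_BASE   { void fastConv() { std::puts("baseline kernel"); } }
namespace opt_AVX2   { void fastConv() { std::puts("AVX2 kernel");     } }
namespace opt_AVX512 { void fastConv() { std::puts("AVX512 kernel");   } }

// Hypothetical stand-ins for cv::checkHardwareSupport(...) results.
static bool useAVX2 = true, useAVX512 = false;

void fastConvDispatch()
{
    if (useAVX512)     opt_AVX512::fastConv();  // widest ISA first
    else if (useAVX2)  opt_AVX2::fastConv();
    else               opt_BASE::fastConv();
}

int main() { fastConvDispatch(); }  // prints "AVX2 kernel"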
modules/dnn/src/layers/convolution_layer.cpp

@@ -345,10 +345,11 @@ public:
         bool is1x1_;
         bool useAVX;
         bool useAVX2;
+        bool useAVX512;
 
         ParallelConv()
             : input_(0), weights_(0), output_(0), ngroups_(0), nstripes_(0),
-              biasvec_(0), reluslope_(0), activ_(0), is1x1_(false), useAVX(false), useAVX2(false)
+              biasvec_(0), reluslope_(0), activ_(0), is1x1_(false), useAVX(false), useAVX2(false), useAVX512(false)
         {}
 
         static void run( const Mat& input, Mat& output, const Mat& weights,
@@ -383,6 +384,7 @@ public:
         p.is1x1_ = kernel == Size(0,0) && pad == Size(0, 0);
         p.useAVX = checkHardwareSupport(CPU_AVX);
         p.useAVX2 = checkHardwareSupport(CPU_AVX2);
+        p.useAVX512 = checkHardwareSupport(CPU_AVX_512DQ);
 
         int ncn = std::min(inpCn, (int)BLK_SIZE_CN);
         p.ofstab_.resize(kernel.width*kernel.height*ncn);
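All of these flags come from OpenCV's runtime feature probe; CPU_AVX_512DQ is the specific AVX-512 subset the patch keys on. A minimal standalone check, assuming only the public cv::checkHardwareSupport API:

#include <opencv2/core/utility.hpp>
#include <cstdio>

int main()
{
    // The same runtime tests the patch performs before taking a wide path.
    bool useAVX2   = cv::checkHardwareSupport(cv::CPU_AVX2);
    bool useAVX512 = cv::checkHardwareSupport(cv::CPU_AVX_512DQ);
    std::printf("AVX2: %d  AVX-512DQ: %d\n", (int)useAVX2, (int)useAVX512);
    return 0;
}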
@@ -562,6 +564,13 @@ public:
                 // now compute dot product of the weights
                 // and im2row-transformed part of the tensor
                 int bsz = ofs1 - ofs0;
+            #if CV_TRY_AVX512
+                /* AVX512 convolution requires an alignment of 16, and ROI is only there for larger vector sizes */
+                if(useAVX512)
+                    opt_AVX512::fastConv(wptr, wstep, biasptr, rowbuf0, data_out0 + ofs0,
+                                         outShape, bsz, vsz, vsz_a, relu, cn0 == 0);
+                else
+            #endif
             #if CV_TRY_AVX2
                 if(useAVX2)
                     opt_AVX2::fastConv(wptr, wstep, biasptr, rowbuf0, data_out0 + ofs0,
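The dispatch idiom above is worth spelling out: each #if CV_TRY_* arm ends in a dangling `else`, so the next arm, or the final scalar fallback, becomes its else-branch, and compiling any arm out leaves a chain that still parses. A reduced sketch of the same construction (flag and function names are placeholders):

#include <cstdio>

#define TRY_AVX512 1
#define TRY_AVX2 1

// puts() calls stand in for the real kernels.
void convolve(bool useAVX512, bool useAVX2)
{
#if TRY_AVX512
    if (useAVX512)
        std::puts("AVX512 path");
    else                       // dangling else: binds to whatever follows
#endif
#if TRY_AVX2
    if (useAVX2)
        std::puts("AVX2 path");
    else
#endif
        std::puts("scalar fallback");
}

int main()
{
    convolve(false, true);   // prints "AVX2 path"
    convolve(false, false);  // prints "scalar fallback"
}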
@@ -1093,6 +1102,7 @@ public:
            nstripes_ = nstripes;
            useAVX = checkHardwareSupport(CPU_AVX);
            useAVX2 = checkHardwareSupport(CPU_AVX2);
+           useAVX512 = checkHardwareSupport(CPU_AVX_512DQ);
        }
 
        void operator()(const Range& range_) const
@@ -1110,6 +1120,11 @@ public:
            size_t bstep = b_->step1();
            size_t cstep = c_->step1();
 
+       #if CV_TRY_AVX512
+           if( useAVX512 )
+               opt_AVX512::fastGEMM( aptr, astep, bptr, bstep, cptr, cstep, mmax, kmax, nmax );
+           else
+       #endif
        #if CV_TRY_AVX2
            if( useAVX2 )
                opt_AVX2::fastGEMM( aptr, astep, bptr, bstep, cptr, cstep, mmax, kmax, nmax );
@@ -1214,6 +1229,7 @@ public:
        int nstripes_;
        bool useAVX;
        bool useAVX2;
+       bool useAVX512;
    };
 
    class Col2ImInvoker : public cv::ParallelLoopBody
modules/dnn/src/layers/fully_connected_layer.cpp

@@ -139,7 +139,7 @@ public:
    class FullyConnected : public ParallelLoopBody
    {
    public:
-       FullyConnected() : srcMat(0), weights(0), biasMat(0), activ(0), dstMat(0), nstripes(0), useAVX(false), useAVX2(false) {}
+       FullyConnected() : srcMat(0), weights(0), biasMat(0), activ(0), dstMat(0), nstripes(0), useAVX(false), useAVX2(false), useAVX512(false) {}
 
        static void run(const Mat& srcMat, const Mat& weights, const Mat& biasMat,
                        Mat& dstMat, const ActivationLayer* activ, int nstripes)
@@ -161,6 +161,7 @@ public:
        p.activ = activ;
        p.useAVX = checkHardwareSupport(CPU_AVX);
        p.useAVX2 = checkHardwareSupport(CPU_AVX2);
+       p.useAVX512 = checkHardwareSupport(CPU_AVX_512DQ);
 
        parallel_for_(Range(0, nstripes), p, nstripes);
    }
@@ -195,6 +196,11 @@ public:
 
                memcpy(sptr, sptr_, vecsize*sizeof(sptr[0]));
 
+           #if CV_TRY_AVX512
+               if( useAVX512 )
+                   opt_AVX512::fastGEMM1T( sptr, wptr, wstep, biasptr, dptr, nw, vecsize);
+               else
+           #endif
            #if CV_TRY_AVX2
                if( useAVX2 )
                    opt_AVX2::fastGEMM1T( sptr, wptr, wstep, biasptr, dptr, nw, vecsize);
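For the fully connected layer the dispatched kernel is a matrix-vector product. Here is a hedged scalar sketch of the fastGEMM1T contract, with the wstep-strided row-major weight layout inferred from the call site above (treat it as an illustration, not the library's definition):

#include <cstddef>

// Inferred semantics: dst[i] = dot(vec, row i of weights) + bias[i],
// for nvecs output neurons with input length vecsize. The real SIMD
// versions in layers_common.simd.hpp tile and vectorize this loop.
void fastGEMM1T_ref(const float* vec, const float* weights, size_t wstep,
                    const float* bias, float* dst, int nvecs, int vecsize)
{
    for (int i = 0; i < nvecs; i++)
    {
        const float* wrow = weights + i * wstep;   // weights for output i
        float s = bias[i];
        for (int k = 0; k < vecsize; k++)
            s += vec[k] * wrow[k];
        dst[i] = s;
    }
}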
@@ -255,6 +261,7 @@ public:
        int nstripes;
        bool useAVX;
        bool useAVX2;
+       bool useAVX512;
    };
 
 #ifdef HAVE_OPENCL
modules/dnn/src/layers/layers_common.simd.hpp

@@ -72,7 +72,7 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
     int outCn = outShape[1];
     size_t outPlaneSize = outShape[2]*outShape[3];
     float r0 = 1.f, r1 = 1.f, r2 = 1.f;
-    __m256 vr0 = _mm256_set1_ps(1.f), vr1 = vr0, vr2 = vr0, z = _mm256_setzero_ps();
+    __m128 vr0 = _mm_set1_ps(1.f), vr1 = vr0, vr2 = vr0, z = _mm_setzero_ps();
 
     // now compute dot product of the weights
     // and im2row-transformed part of the tensor
@@ -104,9 +104,9 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
                 r0 = relu[i];
                 r1 = relu[i+1];
                 r2 = relu[i+2];
-                vr0 = _mm256_set1_ps(r0);
-                vr1 = _mm256_set1_ps(r1);
-                vr2 = _mm256_set1_ps(r2);
+                vr0 = _mm_set1_ps(r0);
+                vr1 = _mm_set1_ps(r1);
+                vr2 = _mm_set1_ps(r2);
             }
 
             int j = 0;
@@ -156,38 +156,38 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
             t1 = _mm256_add_ps(t1, _mm256_permute2f128_ps(t1, t1, 1));
             t2 = _mm256_add_ps(t2, _mm256_permute2f128_ps(t2, t2, 1));
 
-            __m256 s0, s1, s2;
+            __m128 s0, s1, s2;
 
             if( initOutput )
             {
-                s0 = _mm256_set1_ps(bias0);
-                s1 = _mm256_set1_ps(bias1);
-                s2 = _mm256_set1_ps(bias2);
+                s0 = _mm_set1_ps(bias0);
+                s1 = _mm_set1_ps(bias1);
+                s2 = _mm_set1_ps(bias2);
             }
             else
             {
-                s0 = _mm256_castps128_ps256(_mm_loadu_ps(outptr0 + j));
-                s1 = _mm256_castps128_ps256(_mm_loadu_ps(outptr1 + j));
-                s2 = _mm256_castps128_ps256(_mm_loadu_ps(outptr2 + j));
+                s0 = _mm_loadu_ps(outptr0 + j);
+                s1 = _mm_loadu_ps(outptr1 + j);
+                s2 = _mm_loadu_ps(outptr2 + j);
             }
 
-            s0 = _mm256_add_ps(s0, t0);
-            s1 = _mm256_add_ps(s1, t1);
-            s2 = _mm256_add_ps(s2, t2);
+            s0 = _mm_add_ps(s0, _mm256_castps256_ps128(t0));
+            s1 = _mm_add_ps(s1, _mm256_castps256_ps128(t1));
+            s2 = _mm_add_ps(s2, _mm256_castps256_ps128(t2));
 
             if( relu )
             {
-                __m256 m0 = _mm256_cmp_ps(s0, z, _CMP_GT_OS);
-                __m256 m1 = _mm256_cmp_ps(s1, z, _CMP_GT_OS);
-                __m256 m2 = _mm256_cmp_ps(s2, z, _CMP_GT_OS);
-                s0 = _mm256_xor_ps(s0, _mm256_andnot_ps(m0, _mm256_xor_ps(_mm256_mul_ps(s0, vr0), s0)));
-                s1 = _mm256_xor_ps(s1, _mm256_andnot_ps(m1, _mm256_xor_ps(_mm256_mul_ps(s1, vr1), s1)));
-                s2 = _mm256_xor_ps(s2, _mm256_andnot_ps(m2, _mm256_xor_ps(_mm256_mul_ps(s2, vr2), s2)));
+                __m128 m0 = _mm_cmp_ps(s0, z, _CMP_GT_OS);
+                __m128 m1 = _mm_cmp_ps(s1, z, _CMP_GT_OS);
+                __m128 m2 = _mm_cmp_ps(s2, z, _CMP_GT_OS);
+                s0 = _mm_xor_ps(s0, _mm_andnot_ps(m0, _mm_xor_ps(_mm_mul_ps(s0, vr0), s0)));
+                s1 = _mm_xor_ps(s1, _mm_andnot_ps(m1, _mm_xor_ps(_mm_mul_ps(s1, vr1), s1)));
+                s2 = _mm_xor_ps(s2, _mm_andnot_ps(m2, _mm_xor_ps(_mm_mul_ps(s2, vr2), s2)));
             }
 
-            _mm_storeu_ps(outptr0 + j, _mm256_castps256_ps128(s0));
-            _mm_storeu_ps(outptr1 + j, _mm256_castps256_ps128(s1));
-            _mm_storeu_ps(outptr2 + j, _mm256_castps256_ps128(s2));
+            _mm_storeu_ps(outptr0 + j, s0);
+            _mm_storeu_ps(outptr1 + j, s1);
+            _mm_storeu_ps(outptr2 + j, s2);
         }
 
         for( ; j < blockSize; j++ )
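This hunk narrows the 4-wide tail of the AVX2 fastConv from __m256 to __m128: after the horizontal reduction only the low 128 bits of t0..t2 are live, and the old code cast them down right before the store anyway, so doing the bias add, ReLU, and store in SSE registers removes the cast churn without changing results. The xor/andnot sequence is a branchless leaky-ReLU select; a single-lane scalar sketch of the bit trick (illustration only, not part of the patch):

#include <cstdint>
#include <cstring>
#include <cstdio>

// One lane of the _mm_cmp_ps/_mm_andnot_ps/_mm_xor_ps sequence above: a
// branchless select that keeps s where s > 0 and substitutes s*r elsewhere,
// using the identity x ^ (x ^ y) == y.
float leakyReluLane(float s, float r)
{
    float p = s * r;                            // the negative-slope product
    uint32_t si, pi;
    std::memcpy(&si, &s, sizeof si);
    std::memcpy(&pi, &p, sizeof pi);
    uint32_t m = (s > 0.f) ? 0xFFFFFFFFu : 0u;  // _mm_cmp_ps(s, z, _CMP_GT_OS)
    uint32_t out = si ^ (~m & (pi ^ si));       // xor(s, andnot(m, xor(s*r, s)))
    float result;
    std::memcpy(&result, &out, sizeof result);
    return result;                              // s if s > 0, else s * r
}

int main()
{
    std::printf("%g %g\n", leakyReluLane(2.f, 0.1f), leakyReluLane(-2.f, 0.1f));
    // prints: 2 -0.2
}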
@@ -294,11 +294,63 @@ void fastGEMM1T( const float* vec, const float* weights,
     _mm256_zeroupper();
 }
 
 
 void fastGEMM( const float* aptr, size_t astep, const float* bptr,
                size_t bstep, float* cptr, size_t cstep,
                int ma, int na, int nb )
 {
     int n = 0;
 
+#ifdef CV_AVX512
+    for( ; n <= nb - 32; n += 32 )
+    {
+        for( int m = 0; m < ma; m += 4 )
+        {
+            const float* aptr0 = aptr + astep*m;
+            const float* aptr1 = aptr + astep*std::min(m+1, ma-1);
+            const float* aptr2 = aptr + astep*std::min(m+2, ma-1);
+            const float* aptr3 = aptr + astep*std::min(m+3, ma-1);
+
+            float* cptr0 = cptr + cstep*m;
+            float* cptr1 = cptr + cstep*std::min(m+1, ma-1);
+            float* cptr2 = cptr + cstep*std::min(m+2, ma-1);
+            float* cptr3 = cptr + cstep*std::min(m+3, ma-1);
+
+            __m512 d00 = _mm512_setzero_ps(), d01 = _mm512_setzero_ps();
+            __m512 d10 = _mm512_setzero_ps(), d11 = _mm512_setzero_ps();
+            __m512 d20 = _mm512_setzero_ps(), d21 = _mm512_setzero_ps();
+            __m512 d30 = _mm512_setzero_ps(), d31 = _mm512_setzero_ps();
+
+            for( int k = 0; k < na; k++ )
+            {
+                __m512 a0 = _mm512_set1_ps(aptr0[k]);
+                __m512 a1 = _mm512_set1_ps(aptr1[k]);
+                __m512 a2 = _mm512_set1_ps(aptr2[k]);
+                __m512 a3 = _mm512_set1_ps(aptr3[k]);
+                __m512 b0 = _mm512_loadu_ps(bptr + k*bstep + n);
+                __m512 b1 = _mm512_loadu_ps(bptr + k*bstep + n + 16);
+                d00 = _mm512_fmadd_ps(a0, b0, d00);
+                d01 = _mm512_fmadd_ps(a0, b1, d01);
+                d10 = _mm512_fmadd_ps(a1, b0, d10);
+                d11 = _mm512_fmadd_ps(a1, b1, d11);
+                d20 = _mm512_fmadd_ps(a2, b0, d20);
+                d21 = _mm512_fmadd_ps(a2, b1, d21);
+                d30 = _mm512_fmadd_ps(a3, b0, d30);
+                d31 = _mm512_fmadd_ps(a3, b1, d31);
+            }
+
+            _mm512_storeu_ps(cptr0 + n, d00);
+            _mm512_storeu_ps(cptr0 + n + 16, d01);
+            _mm512_storeu_ps(cptr1 + n, d10);
+            _mm512_storeu_ps(cptr1 + n + 16, d11);
+            _mm512_storeu_ps(cptr2 + n, d20);
+            _mm512_storeu_ps(cptr2 + n + 16, d21);
+            _mm512_storeu_ps(cptr3 + n, d30);
+            _mm512_storeu_ps(cptr3 + n + 16, d31);
+        }
+    }
+#endif
+
     for( ; n <= nb - 16; n += 16 )
     {
         for( int m = 0; m < ma; m += 4 )
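The AVX512 path is the same register-tiling scheme as the existing AVX2 loop, widened from 16 to 32 output columns per iteration: eight zmm accumulators (d00..d31) stay resident across the entire k loop, which is affordable because AVX-512 provides 32 vector registers where AVX2 has 16. This is the "more registers available to cache results in" from the commit message. A scalar sketch of what one 4x32 tile computes (for clarity only, not part of the patch):

#include <algorithm>
#include <cstddef>

// C[mi][n+j] = sum over k of A[mi][k] * B[k][n+j] for a 4-row, 32-column
// tile. The std::min clamp mirrors the patch: when ma is not a multiple of
// 4, the surplus rows alias the last valid row, so the kernel harmlessly
// recomputes that row instead of reading out of bounds.
void gemmTileRef(const float* aptr, size_t astep,
                 const float* bptr, size_t bstep,
                 float* cptr, size_t cstep,
                 int m, int ma, int na, int n)
{
    for (int i = 0; i < 4; i++)
    {
        int mi = std::min(m + i, ma - 1);      // row clamp, as in the patch
        const float* arow = aptr + astep * mi;
        float* crow = cptr + cstep * mi;
        for (int j = 0; j < 32; j++)           // 32 columns = two zmm vectors
        {
            float acc = 0.f;                   // one lane of d00..d31
            for (int k = 0; k < na; k++)
                acc += arow[k] * bptr[k * bstep + n + j];
            crow[n + j] = acc;                 // fastGEMM overwrites C
        }
    }
}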