mirror of
https://github.com/opencv/opencv.git
synced 2024-11-24 03:00:14 +08:00
remove old convolution branch, and optimize conv3d and conv1d.
This commit is contained in:
parent
281b790618
commit
71c6339af0
@ -259,7 +259,7 @@ public:
|
||||
std::vector<float> reluslope;
|
||||
Ptr<ActivationLayer> activ;
|
||||
|
||||
Ptr<FastConv2d> fastConv2dImpl;
|
||||
Ptr<FastConv> fastConvImpl;
|
||||
|
||||
#ifdef HAVE_OPENCL
|
||||
Ptr<OCL4DNNConvSpatial<float> > convolutionOp;
|
||||
@ -967,808 +967,6 @@ public:
|
||||
}
|
||||
#endif // HAVE_WEBNN
|
||||
|
||||
class ParallelConv : public cv::ParallelLoopBody
|
||||
{
|
||||
public:
|
||||
enum { BLK_SIZE = 32, BLK_SIZE_CN = 64 };
|
||||
|
||||
const Mat* input_;
|
||||
const Mat* weights_;
|
||||
Mat* output_;
|
||||
int outShape[4]; // used only for conv2d
|
||||
std::vector<size_t> kernel_size, pads_begin, pads_end, strides, dilations;
|
||||
int ngroups_, nstripes_;
|
||||
std::vector<int> ofstab_;
|
||||
const std::vector<float>* biasvec_;
|
||||
const std::vector<float>* reluslope_;
|
||||
const ActivationLayer* activ_;
|
||||
bool is1x1_;
|
||||
bool useAVX;
|
||||
bool useAVX2;
|
||||
bool useAVX512;
|
||||
bool useRVV;
|
||||
bool useLASX;
|
||||
int blk_size_cn;
|
||||
|
||||
ParallelConv()
|
||||
: input_(0), weights_(0), output_(0), ngroups_(0), nstripes_(0),
|
||||
biasvec_(0), reluslope_(0), activ_(0), is1x1_(false), useAVX(false), useAVX2(false), useAVX512(false), useRVV(false)
|
||||
, useLASX(false), blk_size_cn(0)
|
||||
{}
|
||||
|
||||
static void run( const Mat& input, Mat& output, const Mat& weights,
|
||||
const std::vector<float>& biasvec,
|
||||
const std::vector<float>& reluslope,
|
||||
const std::vector<size_t>& kernel_size, const std::vector<size_t>& strides,
|
||||
const std::vector<size_t>& pads_begin, const std::vector<size_t>& pads_end,
|
||||
const std::vector<size_t>& dilations,
|
||||
const ActivationLayer* activ, int ngroups, int nstripes )
|
||||
{
|
||||
size_t karea = std::accumulate(kernel_size.begin(), kernel_size.end(),
|
||||
1, std::multiplies<size_t>());
|
||||
bool isConv1D = input.dims == 3;
|
||||
bool isConv2D = input.dims == 4;
|
||||
bool isConv3D = input.dims == 5;
|
||||
CV_CheckEQ(static_cast<int>(kernel_size.size()), input.dims - 2, "");
|
||||
CV_Assert_N(input.dims == output.dims,
|
||||
input.size[0] == output.size[0],
|
||||
weights.rows == output.size[1],
|
||||
weights.cols == (input.size[1]/ngroups)*karea,
|
||||
input.type() == output.type(),
|
||||
input.type() == weights.type(),
|
||||
input.type() == CV_32FC1,
|
||||
input.isContinuous(),
|
||||
output.isContinuous(),
|
||||
biasvec.size() == (size_t)output.size[1]+2);
|
||||
CV_Check(weights.step1(), weights.step1() % VEC_ALIGN == 0, "");
|
||||
CV_CheckType(weights.type(), CV_32FC1, "");
|
||||
ParallelConv p;
|
||||
|
||||
p.input_ = &input;
|
||||
p.weights_ = &weights;
|
||||
p.output_ = &output;
|
||||
int max_ind = isConv1D? 3: 4;
|
||||
for( int i = 0; i < max_ind; i++ ) p.outShape[i] = output.size[i];
|
||||
p.outShape[1] /= ngroups;
|
||||
|
||||
p.kernel_size = kernel_size; p.strides = strides; p.dilations = dilations;
|
||||
p.pads_begin = pads_begin; p.pads_end = pads_end;
|
||||
|
||||
p.ngroups_ = ngroups;
|
||||
p.nstripes_ = nstripes;
|
||||
|
||||
int inpCnAll = input.size[1];
|
||||
int depth = (input.dims == 5) ? input.size[2] : 1;
|
||||
int width = input.size[input.dims - 1];
|
||||
int height = isConv1D? 1 : input.size[input.dims - 2];
|
||||
int inpCn = inpCnAll / ngroups;
|
||||
|
||||
p.is1x1_ = (isConv2D && kernel_size[0] == 1 && kernel_size[1] == 1 &&
|
||||
pads_begin[0] == 0 && pads_begin[1] == 0) ||
|
||||
(isConv1D && pads_begin[0] == 0 && kernel_size[0] == 1);
|
||||
|
||||
p.useAVX = checkHardwareSupport(CPU_AVX) && isConv2D;
|
||||
p.useAVX2 = checkHardwareSupport(CPU_AVX2) && isConv2D;
|
||||
p.useAVX512 = CV_CPU_HAS_SUPPORT_AVX512_SKX && isConv2D;
|
||||
p.useRVV = checkHardwareSupport(CPU_RVV) && isConv2D;
|
||||
p.useLASX = checkHardwareSupport(CPU_LASX) && isConv2D;
|
||||
|
||||
int kernel_d = isConv3D? kernel_size[0] : 1;
|
||||
int kernel_h = isConv1D? 1 : kernel_size[kernel_size.size() - 2];
|
||||
int kernel_w = kernel_size.back();
|
||||
|
||||
int blk_size_cn0 = cvCeil(800./(kernel_w*kernel_h));
|
||||
int ncn = 16;
|
||||
while (ncn*2 < blk_size_cn0 && ncn < inpCn)
|
||||
ncn *= 2;
|
||||
ncn = std::min(ncn, inpCn);
|
||||
p.blk_size_cn = ncn;
|
||||
|
||||
int dil_d = isConv3D? dilations[0] : 1;
|
||||
int dil_h = isConv1D? 1 : dilations[dilations.size() - 2];
|
||||
int dil_w = dilations.back();
|
||||
|
||||
p.ofstab_.resize(karea * ncn);
|
||||
int* ofstab = &p.ofstab_[0];
|
||||
|
||||
if (isConv1D)
|
||||
{
|
||||
for( int k = 0; k < ncn; k++ )
|
||||
for( int k_c = 0; k_c < kernel_w; k_c++ )
|
||||
ofstab[k*kernel_w + k_c] = k*width + k_c*dil_w;
|
||||
}
|
||||
else if (isConv2D)
|
||||
{
|
||||
for( int k = 0; k < ncn; k++ )
|
||||
for( int k_r = 0; k_r < kernel_h; k_r++ )
|
||||
for( int k_c = 0; k_c < kernel_w; k_c++ )
|
||||
ofstab[(k*kernel_h + k_r)*kernel_w + k_c] =
|
||||
(k*height + k_r*dil_h)*width + k_c*dil_w;
|
||||
}
|
||||
else
|
||||
{
|
||||
for( int k = 0; k < ncn; k++ )
|
||||
for (int k_d = 0; k_d < kernel_d; k_d++)
|
||||
for( int k_r = 0; k_r < kernel_h; k_r++ )
|
||||
for( int k_c = 0; k_c < kernel_w; k_c++ )
|
||||
ofstab[(k*kernel_d*kernel_h + k_d*kernel_h + k_r)*kernel_w + k_c] =
|
||||
(k*depth*height + k_d*dil_d*height + k_r*dil_h)*width + k_c*dil_w;
|
||||
}
|
||||
|
||||
p.biasvec_ = &biasvec;
|
||||
p.reluslope_ = &reluslope;
|
||||
p.activ_ = p.reluslope_->empty() ? activ : 0;
|
||||
|
||||
parallel_for_(Range(0, nstripes), p, nstripes);
|
||||
}
|
||||
|
||||
virtual void operator ()(const Range &r0) const CV_OVERRIDE
|
||||
{
|
||||
const int valign = ConvolutionLayerImpl::VEC_ALIGN;
|
||||
int ngroups = ngroups_, batchSize = input_->size[0]*ngroups;
|
||||
bool isConv1D = input_->dims == 3;
|
||||
bool isConv2D = input_->dims == 4;
|
||||
bool isConv3D = input_->dims == 5;
|
||||
|
||||
int outW = output_->size[output_->dims - 1];
|
||||
int outH = isConv1D? 1 : output_->size[output_->dims - 2];
|
||||
int outCn = output_->size[1]/ngroups;
|
||||
|
||||
int depth = isConv3D? input_->size[2] : 1;
|
||||
int height = isConv1D? 1 : input_->size[input_->dims - 2];
|
||||
int width = input_->size[input_->dims - 1];
|
||||
int inpCn = input_->size[1]/ngroups;
|
||||
|
||||
const int nstripes = nstripes_;
|
||||
|
||||
int kernel_d = isConv3D? kernel_size[0] : 1;
|
||||
int kernel_h = isConv1D? 1 : kernel_size[kernel_size.size() - 2];
|
||||
int kernel_w = kernel_size.back();
|
||||
int karea = kernel_w*kernel_h*kernel_d;
|
||||
|
||||
int pad_d = isConv3D? pads_begin[0] : 0;
|
||||
int pad_t = isConv1D? 0 : pads_begin[pads_begin.size() - 2];
|
||||
int pad_l = pads_begin.back();
|
||||
|
||||
int stride_d = isConv3D? strides[0] : 0;
|
||||
int stride_h = isConv1D? 0 : strides[strides.size() - 2];
|
||||
int stride_w = strides.back();
|
||||
|
||||
int dilation_d = isConv3D? dilations[0] : 1;
|
||||
int dilation_h = isConv1D? 1 : dilations[dilations.size() - 2];
|
||||
int dilation_w = dilations.back();
|
||||
|
||||
int i, j, k, d;
|
||||
int inpPlaneSize = (int)input_->total(2);
|
||||
int outPlaneSize = (int)output_->total(2);
|
||||
bool is1x1 = is1x1_;
|
||||
|
||||
int stripesPerSample;
|
||||
int stripeSize;
|
||||
Range r = r0;
|
||||
bool depthWiseConvolution = !is1x1 && isConv2D && ngroups > 1 && inpCn == 1 &&
|
||||
outCn == 1 && kernel_d == 1 && dilation_d == 1 && stride_d == 0 && pad_d == 0 &&
|
||||
width >= 16 + dilation_w*(kernel_w - 1);
|
||||
// for now only 3x3 depth-wise convolutions are supported
|
||||
depthWiseConvolution = depthWiseConvolution && kernel_w == 3 && kernel_h == 3 &&
|
||||
// computing at most 1 pixel from each side can involve padding
|
||||
max(stride_w, dilation_w) >= pad_l && max(stride_h, dilation_h) >= pad_t &&
|
||||
pad_l <= 1 && pad_t <= 1;
|
||||
|
||||
if( !depthWiseConvolution && nstripes >= batchSize*2 )
|
||||
{
|
||||
stripesPerSample = nstripes/batchSize;
|
||||
stripeSize = (int)alignSize((outPlaneSize + stripesPerSample - 1)/stripesPerSample, valign);
|
||||
stripeSize = std::min(stripeSize, outPlaneSize);
|
||||
}
|
||||
else
|
||||
{
|
||||
stripesPerSample = 1;
|
||||
int samplesPerStripe = std::max((batchSize + nstripes - 1)/nstripes, 1);
|
||||
r.start *= samplesPerStripe;
|
||||
r.end *= samplesPerStripe;
|
||||
stripeSize = outPlaneSize;
|
||||
}
|
||||
|
||||
const float* data_inp0_ = input_->ptr<float>();
|
||||
const int* ofstab = &ofstab_[0];
|
||||
const float* wptr_orig_ = weights_->ptr<float>();
|
||||
size_t wstep = weights_->step1();
|
||||
const float* biasptr_ = &biasvec_->at(0);
|
||||
const float* reluptr_ = reluslope_->empty() ? 0 : &reluslope_->at(0);
|
||||
float* data_out0_ = output_->ptr<float>();
|
||||
AutoBuffer<float> rowbuf0_;
|
||||
float* rowbuf0 = 0;
|
||||
bool use_rowbuf = !depthWiseConvolution;
|
||||
int blk_size = depthWiseConvolution ? outPlaneSize : min((int)BLK_SIZE, stripeSize);
|
||||
|
||||
// im2row buffer is not used for depth-wise convolution
|
||||
if(use_rowbuf)
|
||||
{
|
||||
size_t rowbufsz = alignSize(karea*blk_size_cn, valign)*min((int)BLK_SIZE, blk_size);
|
||||
//printf("karea=%d, blk_size_cn=%d, rowbufsz=%d, stripeSize=%d\n", karea, blk_size_cn, (int)rowbufsz, stripeSize);
|
||||
rowbuf0_.allocate(rowbufsz + valign);
|
||||
rowbuf0 = alignPtr(rowbuf0_.data(), (int)(valign*sizeof(float)));
|
||||
// we clear the buffer once; ultimately, it lets us to avoid
|
||||
// tail processing after running the unrolled/vectorized loop.
|
||||
// the main idea is to make sure that the tail (a.k.a. padding) of each row
|
||||
// (i.e. the elements with indices between vsz=karea*ncn and vsz_a)
|
||||
// does not contain NaNs or Infs. Because the padding in the weights
|
||||
// matrix is explicitly initialized with 0's, we handle all other
|
||||
// cases nicely, i.e. we can skip expliciting re-initialization
|
||||
// of the padding - we just retain elements from the previous iteration
|
||||
// of the loop over channels (cn0).
|
||||
memset(rowbuf0, 0, rowbufsz*sizeof(rowbuf0[0]) );
|
||||
}
|
||||
|
||||
for( int stripe = r.start; stripe < r.end; stripe++ )
|
||||
{
|
||||
int subsampleIdx = stripe/stripesPerSample;
|
||||
if( subsampleIdx >= batchSize )
|
||||
break;
|
||||
int stripeStart = (int)((stripe - subsampleIdx*stripesPerSample)*stripeSize);
|
||||
int stripeEnd = (int)std::min(stripeStart + stripeSize, outPlaneSize);
|
||||
const float* data_inp0 = data_inp0_ + subsampleIdx*inpPlaneSize*inpCn;
|
||||
float* data_out0 = data_out0_ + subsampleIdx*outPlaneSize*outCn;
|
||||
int startOutCn = (subsampleIdx % ngroups)*outCn;
|
||||
const float* wptr_orig = wptr_orig_ + wstep*startOutCn;
|
||||
const float* biasptr = biasptr_ + startOutCn;
|
||||
|
||||
for( int cn0 = 0; cn0 < inpCn; cn0 += blk_size_cn )
|
||||
{
|
||||
int cn1 = std::min(cn0 + blk_size_cn, inpCn);
|
||||
int ncn = cn1 - cn0, vsz = karea*ncn;
|
||||
int vsz_a = (int)alignSize(vsz, valign);
|
||||
const float* wptr = wptr_orig + cn0*karea;
|
||||
// we apply [Channels][P]ReLU (if any) during the final pass only.
|
||||
const float* relu = cn1 == inpCn && reluptr_ ? reluptr_ + startOutCn : 0;
|
||||
|
||||
for( int ofs0 = stripeStart; ofs0 < stripeEnd; ofs0 += blk_size )
|
||||
{
|
||||
int ofs, ofs1 = std::min(ofs0 + blk_size, stripeEnd);
|
||||
int bsz = ofs1 - ofs0;
|
||||
|
||||
int out_d = ofs0 / (outH * outW);
|
||||
int out_i = (ofs0 - out_d * outH * outW) / outW;
|
||||
int out_j = ofs0 % outW;
|
||||
|
||||
if (depthWiseConvolution)
|
||||
{
|
||||
CV_Assert(out_i == 0 && out_j == 0);
|
||||
int in_d = out_d * stride_d - pad_d;
|
||||
const float* inptr_ = data_inp0 + (cn0*depth*height + in_d*height)*width;
|
||||
float* outptr_ = data_out0 + ofs0;
|
||||
|
||||
#if CV_TRY_AVX2
|
||||
if(useAVX2)
|
||||
opt_AVX2::fastDepthwiseConv(wptr, kernel_h, kernel_w,
|
||||
stride_h, stride_w, dilation_h, dilation_w, pad_t, pad_l,
|
||||
biasptr, relu, inptr_, height, width, outptr_, out_d, outH, outW);
|
||||
else
|
||||
#endif
|
||||
#if CV_TRY_AVX
|
||||
if(useAVX)
|
||||
opt_AVX::fastDepthwiseConv(wptr, kernel_h, kernel_w,
|
||||
stride_h, stride_w, dilation_h, dilation_w, pad_t, pad_l,
|
||||
biasptr, relu, inptr_, height, width, outptr_, out_d, outH, outW);
|
||||
else
|
||||
#endif
|
||||
#if CV_TRY_RVV
|
||||
if(useRVV)
|
||||
opt_RVV::fastDepthwiseConv(wptr, kernel_h, kernel_w,
|
||||
stride_h, stride_w, dilation_h, dilation_w, pad_t, pad_l,
|
||||
biasptr, relu, inptr_, height, width, outptr_, out_d, outH, outW);
|
||||
else
|
||||
#endif
|
||||
#if CV_TRY_LASX
|
||||
if(useLASX)
|
||||
opt_LASX::fastDepthwiseConv(wptr, kernel_h, kernel_w,
|
||||
stride_h, stride_w, dilation_h, dilation_w, pad_t, pad_l,
|
||||
biasptr, relu, inptr_, height, width, outptr_, out_d, outH, outW);
|
||||
else
|
||||
#endif
|
||||
{
|
||||
const float w00_ = wptr[0], w01_ = wptr[1], w02_ = wptr[2],
|
||||
w10 = wptr[3], w11 = wptr[4], w12 = wptr[5],
|
||||
w20_ = wptr[6], w21_ = wptr[7], w22_ = wptr[8];
|
||||
int outW1 = min(outW, (width - dilation_w*(kernel_w - 1) + pad_l)/stride_w);
|
||||
float relu_coeff = relu ? relu[out_d] : 1.f, bias = biasptr[out_d];
|
||||
|
||||
for (int out_i = 0; out_i < outH; out_i++)
|
||||
{
|
||||
int in_i = out_i * stride_h - pad_t, out_j = 0;
|
||||
const float* imgptr0 = inptr_ + in_i*width;
|
||||
const float* imgptr1 = imgptr0 + dilation_h*width;
|
||||
const float* imgptr2 = imgptr0 + (dilation_h*2)*width;
|
||||
float out, w00 = w00_, w01 = w01_, w02 = w02_;
|
||||
float w20 = w20_, w21 = w21_, w22 = w22_;
|
||||
if (in_i < 0)
|
||||
{
|
||||
w00 = w01 = w02 = 0.f;
|
||||
imgptr0 = imgptr1;
|
||||
}
|
||||
else if (in_i + dilation_h*(kernel_h-1) >= height)
|
||||
{
|
||||
w20 = w21 = w22 = 0.f;
|
||||
imgptr2 = imgptr1;
|
||||
}
|
||||
float* outptr = outptr_ + out_i*outW;
|
||||
if (pad_l > 0)
|
||||
{
|
||||
out = imgptr0[0]*w01 + imgptr0[dilation_w]*w02 +
|
||||
imgptr1[0]*w11 + imgptr1[dilation_w]*w12 +
|
||||
imgptr2[0]*w21 + imgptr2[dilation_w]*w22 + bias;
|
||||
if (relu)
|
||||
out = out > 0.f ? out : out*relu_coeff;
|
||||
outptr[0] = out;
|
||||
out_j = 1;
|
||||
}
|
||||
|
||||
#if CV_SIMD
|
||||
// maybe with AVX or AVX512 strided depthwise convolution
|
||||
// can be accelerated with vector code, but with 4xfloat vectors
|
||||
// it's hardly the case
|
||||
if( stride_w == 1 )
|
||||
{
|
||||
const int VECSZ = v_float32::nlanes;
|
||||
const int out_delta = VECSZ/stride_w;
|
||||
v_float32 vw00 = vx_setall_f32(w00), vw01 = vx_setall_f32(w01), vw02 = vx_setall_f32(w02),
|
||||
vw10 = vx_setall_f32(w10), vw11 = vx_setall_f32(w11), vw12 = vx_setall_f32(w12),
|
||||
vw20 = vx_setall_f32(w20), vw21 = vx_setall_f32(w21), vw22 = vx_setall_f32(w22);
|
||||
v_float32 z = vx_setzero_f32(), vbias = vx_setall_f32(bias), vrc = vx_setall_f32(relu_coeff);
|
||||
for( ; out_j < outW1; out_j += out_delta )
|
||||
{
|
||||
if (out_j + out_delta > outW1)
|
||||
{
|
||||
if (out_j <= pad_l)
|
||||
break;
|
||||
out_j = outW1 - out_delta;
|
||||
}
|
||||
int in_j = out_j * stride_w - pad_l;
|
||||
v_float32 v00 = vx_load(imgptr0 + in_j),
|
||||
v01 = vx_load(imgptr0 + in_j + dilation_w),
|
||||
v02 = vx_load(imgptr0 + in_j + dilation_w*2),
|
||||
v10 = vx_load(imgptr1 + in_j),
|
||||
v11 = vx_load(imgptr1 + in_j + dilation_w),
|
||||
v12 = vx_load(imgptr1 + in_j + dilation_w*2),
|
||||
v20 = vx_load(imgptr2 + in_j),
|
||||
v21 = vx_load(imgptr2 + in_j + dilation_w),
|
||||
v22 = vx_load(imgptr2 + in_j + dilation_w*2);
|
||||
|
||||
v_float32 vout = v00*vw00 + v01*vw01 + v02*vw02 +
|
||||
v10*vw10 + v11*vw11 + v12*vw12 +
|
||||
v20*vw20 + v21*vw21 + v22*vw22 + vbias;
|
||||
if (relu)
|
||||
vout = v_select(vout > z, vout, vout*vrc);
|
||||
v_store(outptr + out_j, vout);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
for (; out_j < outW1; out_j++)
|
||||
{
|
||||
int in_j = out_j * stride_w - pad_l;
|
||||
out = imgptr0[in_j]*w00 + imgptr0[in_j + dilation_w]*w01 + imgptr0[in_j + dilation_w*2]*w02 +
|
||||
imgptr1[in_j]*w10 + imgptr1[in_j + dilation_w]*w11 + imgptr1[in_j + dilation_w*2]*w12 +
|
||||
imgptr2[in_j]*w20 + imgptr2[in_j + dilation_w]*w21 + imgptr2[in_j + dilation_w*2]*w22 + bias;
|
||||
if (relu)
|
||||
out = out > 0.f ? out : out*relu_coeff;
|
||||
outptr[out_j] = out;
|
||||
}
|
||||
|
||||
for (; out_j < outW; out_j++ )
|
||||
{
|
||||
int in_j0 = out_j * stride_w - pad_l, in_j1 = in_j0 + dilation_w, in_j2 = in_j0 + dilation_w*2;
|
||||
float s0 = 1.f, s1 = 1.f, s2 = 1.f;
|
||||
if (in_j0 >= width)
|
||||
{
|
||||
in_j0 = 0;
|
||||
s0 = 0.f;
|
||||
}
|
||||
if (in_j1 >= width)
|
||||
{
|
||||
in_j1 = 0;
|
||||
s1 = 0.f;
|
||||
}
|
||||
if (in_j2 >= width)
|
||||
{
|
||||
in_j2 = 0;
|
||||
s2 = 0.f;
|
||||
}
|
||||
out = imgptr0[in_j0]*w00*s0 + imgptr0[in_j1]*w01*s1 + imgptr0[in_j2]*w02*s2 +
|
||||
imgptr1[in_j0]*w10*s0 + imgptr1[in_j1]*w11*s1 + imgptr1[in_j2]*w12*s2 +
|
||||
imgptr2[in_j0]*w20*s0 + imgptr2[in_j1]*w21*s1 + imgptr2[in_j2]*w22*s2 + bias;
|
||||
if (relu)
|
||||
out = out > 0.f ? out : out*relu_coeff;
|
||||
outptr[out_j] = out;
|
||||
}
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// do im2row for a part of input tensor
|
||||
float* rowbuf = rowbuf0;
|
||||
|
||||
if (isConv1D)
|
||||
{
|
||||
for( ofs = ofs0; ofs < ofs1; out_j = 0, ++out_i )
|
||||
{
|
||||
int delta = std::min(ofs1 - ofs, outW - out_j);
|
||||
int out_j1 = out_j + delta;
|
||||
|
||||
int in_j = out_j * stride_w - pad_l;
|
||||
const float* imgptr = data_inp0 + cn0*width + in_j;
|
||||
ofs += delta;
|
||||
|
||||
// do im2row for a part of input tensor
|
||||
if( is1x1 )
|
||||
{
|
||||
for( ; out_j < out_j1; out_j++, rowbuf += vsz_a, imgptr += stride_w )
|
||||
{
|
||||
for( k = 0; k < vsz; k++ )
|
||||
rowbuf[k] = imgptr[k*inpPlaneSize];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for( ; out_j < out_j1; out_j++, rowbuf += vsz_a, imgptr += stride_w, in_j += stride_w )
|
||||
{
|
||||
// this condition should be true for most of the tensor elements, i.e.
|
||||
// most of the time the kernel aperture is inside the tensor X-Y plane.
|
||||
if( out_j + 2 <= out_j1 && 0 <= in_j && in_j + stride_w*2 <= width - (kernel_w-1)*dilation_w )
|
||||
{
|
||||
for( k = 0; k < vsz; k++ )
|
||||
{
|
||||
int k1 = ofstab[k];
|
||||
float v0 = imgptr[k1];
|
||||
float v1 = imgptr[k1 + stride_w];
|
||||
rowbuf[k] = v0;
|
||||
rowbuf[k+vsz_a] = v1;
|
||||
}
|
||||
out_j++;
|
||||
rowbuf += vsz_a;
|
||||
imgptr += stride_w;
|
||||
in_j += stride_w;
|
||||
}
|
||||
else
|
||||
{
|
||||
int i0 = std::max(0, (-in_j + dilation_w-1)/dilation_w);
|
||||
int i1 = std::min(kernel_w, (width - in_j + dilation_w-1)/dilation_w);
|
||||
|
||||
// here some non-continuous sub-row of the row will not be
|
||||
// filled from the tensor; we need to make sure that the uncovered
|
||||
// elements are explicitly set to 0's. the easiest way is to
|
||||
// set all the elements to 0's before the loop.
|
||||
memset(rowbuf, 0, vsz*sizeof(rowbuf[0]));
|
||||
for( k = 0; k < ncn; k++ )
|
||||
{
|
||||
for( i = i0; i < i1; i++ )
|
||||
{
|
||||
int imgofs = k*width + i*dilation_w;
|
||||
rowbuf[k*kernel_w + i] = imgptr[imgofs];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (isConv2D)
|
||||
{
|
||||
if( is1x1 && stride_w == 1 && stride_h == 1 )
|
||||
{
|
||||
const float* imgptr = data_inp0 + (cn0*height + out_i)*width + out_j;
|
||||
for( int j = 0; j < bsz; j++, rowbuf += vsz_a )
|
||||
{
|
||||
if( j + 4 <= bsz )
|
||||
{
|
||||
k = 0;
|
||||
#if CV_SIMD128
|
||||
for( ; k <= vsz - 4; k += 4 )
|
||||
{
|
||||
const float* inp = imgptr + j + k*inpPlaneSize;
|
||||
v_float32x4 p0 = v_load(inp), p1 = v_load(inp + inpPlaneSize);
|
||||
v_float32x4 p2 = v_load(inp + inpPlaneSize*2), p3 = v_load(inp + inpPlaneSize*3);
|
||||
v_float32x4 r0, r1, r2, r3;
|
||||
v_transpose4x4(p0, p1, p2, p3, r0, r1, r2, r3);
|
||||
v_store(rowbuf + k, r0);
|
||||
v_store(rowbuf + k + vsz_a, r1);
|
||||
v_store(rowbuf + k + vsz_a*2, r2);
|
||||
v_store(rowbuf + k + vsz_a*3, r3);
|
||||
}
|
||||
#endif
|
||||
for( ; k < vsz; k++ )
|
||||
{
|
||||
const float* inp = imgptr + j + k*inpPlaneSize;
|
||||
float v0 = inp[0], v1 = inp[1], v2 = inp[2], v3 = inp[3];
|
||||
rowbuf[k] = v0;
|
||||
rowbuf[k + vsz_a] = v1;
|
||||
rowbuf[k + vsz_a*2] = v2;
|
||||
rowbuf[k + vsz_a*3] = v3;
|
||||
}
|
||||
j += 3;
|
||||
rowbuf += vsz_a*3;
|
||||
}
|
||||
else
|
||||
{
|
||||
for( k = 0; k < vsz; k++ )
|
||||
{
|
||||
rowbuf[k] = imgptr[j + k*inpPlaneSize];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
for( ofs = ofs0; ofs < ofs1; out_j = 0, ++out_i )
|
||||
{
|
||||
int delta = std::min(ofs1 - ofs, outW - out_j);
|
||||
int out_j1 = out_j + delta;
|
||||
|
||||
int in_i = out_i * stride_h - pad_t;
|
||||
int in_j = out_j * stride_w - pad_l;
|
||||
const float* imgptr = data_inp0 + (cn0*height + in_i)*width + in_j;
|
||||
ofs += delta;
|
||||
|
||||
// do im2row for a part of input tensor
|
||||
if( is1x1 )
|
||||
{
|
||||
for( ; out_j < out_j1; out_j++, rowbuf += vsz_a, imgptr += stride_w )
|
||||
{
|
||||
for( k = 0; k < vsz; k++ )
|
||||
rowbuf[k] = imgptr[k*inpPlaneSize];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
bool ok_i = 0 <= in_i && in_i < height - (kernel_h-1)*dilation_h;
|
||||
int i0 = std::max(0, (-in_i + dilation_h-1)/dilation_h);
|
||||
int i1 = std::min(kernel_h, (height - in_i + dilation_h-1)/dilation_h);
|
||||
|
||||
for( ; out_j < out_j1; out_j++, rowbuf += vsz_a, imgptr += stride_w, in_j += stride_w )
|
||||
{
|
||||
// this condition should be true for most of the tensor elements, i.e.
|
||||
// most of the time the kernel aperture is inside the tensor X-Y plane.
|
||||
if( ok_i && out_j + 2 <= out_j1 && 0 <= in_j && in_j + stride_w*2 <= width - (kernel_w-1)*dilation_w )
|
||||
{
|
||||
for( k = 0; k < vsz; k++ )
|
||||
{
|
||||
int k1 = ofstab[k];
|
||||
float v0 = imgptr[k1];
|
||||
float v1 = imgptr[k1 + stride_w];
|
||||
rowbuf[k] = v0;
|
||||
rowbuf[k+vsz_a] = v1;
|
||||
}
|
||||
out_j++;
|
||||
rowbuf += vsz_a;
|
||||
imgptr += stride_w;
|
||||
in_j += stride_w;
|
||||
}
|
||||
else
|
||||
{
|
||||
int j0 = std::max(0, (-in_j + dilation_w-1)/dilation_w);
|
||||
int j1 = std::min(kernel_w, (width - in_j + dilation_w-1)/dilation_w);
|
||||
|
||||
// here some non-continuous sub-row of the row will not be
|
||||
// filled from the tensor; we need to make sure that the uncovered
|
||||
// elements are explicitly set to 0's. the easiest way is to
|
||||
// set all the elements to 0's before the loop.
|
||||
memset(rowbuf, 0, vsz*sizeof(rowbuf[0]));
|
||||
for( k = 0; k < ncn; k++ )
|
||||
{
|
||||
for( i = i0; i < i1; i++ )
|
||||
{
|
||||
for( j = j0; j < j1; j++ )
|
||||
{
|
||||
int imgofs = k*(width*height) + i*(dilation_h*width) + j*dilation_w;
|
||||
rowbuf[(k*kernel_h + i)*kernel_w + j] = imgptr[imgofs];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for( ofs = ofs0; ofs < ofs1; out_d += (out_i + 1) / outH, out_i = (out_i + 1) % outH, out_j = 0 )
|
||||
{
|
||||
int delta = std::min(ofs1 - ofs, outW - out_j);
|
||||
int out_j1 = out_j + delta;
|
||||
|
||||
int in_d = out_d * stride_d - pad_d;
|
||||
int in_i = out_i * stride_h - pad_t;
|
||||
int in_j = out_j * stride_w - pad_l;
|
||||
const float* imgptr = data_inp0 + (cn0*depth*height + in_d*height + in_i)*width + in_j;
|
||||
ofs += delta;
|
||||
|
||||
int d0 = std::max(0, (-in_d + dilation_d - 1) / dilation_d);
|
||||
int d1 = std::min(kernel_d, (depth - in_d + dilation_d - 1) / dilation_d);
|
||||
|
||||
int i0 = std::max(0, (-in_i + dilation_h-1)/dilation_h);
|
||||
int i1 = std::min(kernel_h, (height - in_i + dilation_h-1)/dilation_h);
|
||||
|
||||
for( ; out_j < out_j1; out_j++, rowbuf += vsz_a, imgptr += stride_w, in_j += stride_w )
|
||||
{
|
||||
int j0 = std::max(0, (-in_j + dilation_w-1)/dilation_w);
|
||||
int j1 = std::min(kernel_w, (width - in_j + dilation_w-1)/dilation_w);
|
||||
|
||||
// here some non-continuous sub-row of the row will not be
|
||||
// filled from the tensor; we need to make sure that the uncovered
|
||||
// elements are explicitly set to 0's. the easiest way is to
|
||||
// set all the elements to 0's before the loop.
|
||||
memset(rowbuf, 0, vsz*sizeof(rowbuf[0]));
|
||||
for( k = 0; k < ncn; k++ )
|
||||
{
|
||||
for ( d = d0; d < d1; d++)
|
||||
{
|
||||
for( i = i0; i < i1; i++ )
|
||||
{
|
||||
for( j = j0; j < j1; j++ )
|
||||
{
|
||||
int imgofs = k*(depth*width*height) + d*dilation_d*width*height + i*(dilation_h*width) + j*dilation_w;
|
||||
rowbuf[(k*kernel_d*kernel_h + d*kernel_h + i)*kernel_w + j] = imgptr[imgofs];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// now compute dot product of the weights
|
||||
// and im2row-transformed part of the tensor
|
||||
#if CV_TRY_AVX512_SKX
|
||||
/* AVX512 convolution requires an alignment of 16, and ROI is only there for larger vector sizes */
|
||||
if(useAVX512)
|
||||
opt_AVX512_SKX::fastConv(wptr, wstep, biasptr, rowbuf0, data_out0 + ofs0,
|
||||
outShape, bsz, vsz, vsz_a, relu, cn0 == 0);
|
||||
else
|
||||
#endif
|
||||
#if CV_TRY_AVX2
|
||||
if(useAVX2)
|
||||
opt_AVX2::fastConv(wptr, wstep, biasptr, rowbuf0, data_out0 + ofs0,
|
||||
outShape, bsz, vsz, vsz_a, relu, cn0 == 0);
|
||||
else
|
||||
#endif
|
||||
#if CV_TRY_AVX
|
||||
if(useAVX)
|
||||
opt_AVX::fastConv(wptr, wstep, biasptr, rowbuf0, data_out0 + ofs0,
|
||||
outShape, bsz, vsz, vsz_a, relu, cn0 == 0);
|
||||
else
|
||||
#endif
|
||||
#if CV_TRY_RVV
|
||||
if(useRVV)
|
||||
opt_RVV::fastConv(wptr, wstep, biasptr, rowbuf0, data_out0 + ofs0,
|
||||
outShape, bsz, vsz, vsz_a, relu, cn0 == 0);
|
||||
else
|
||||
#endif
|
||||
#if CV_TRY_LASX
|
||||
if(useLASX)
|
||||
opt_LASX::fastConv(wptr, wstep, biasptr, rowbuf0, data_out0 + ofs0,
|
||||
outShape, bsz, vsz, vsz_a, relu, cn0 == 0);
|
||||
else
|
||||
#endif
|
||||
for( int i = 0; i < outCn; i += 2 )
|
||||
{
|
||||
const float* wptr0 = wptr + i*wstep;
|
||||
const float* wptr1 = wptr0 + wstep;
|
||||
float* outptr0 = data_out0 + ofs0 + i*outPlaneSize;
|
||||
float* outptr1 = outptr0 + outPlaneSize;
|
||||
float bias0 = biasptr[i], bias1 = biasptr[i+1];
|
||||
float r0 = 1.f, r1 = 1.f;
|
||||
|
||||
if( i+1 >= outCn )
|
||||
{
|
||||
wptr1 = wptr0;
|
||||
outptr1 = outptr0;
|
||||
bias1 = bias0;
|
||||
}
|
||||
|
||||
if( relu )
|
||||
{
|
||||
r0 = relu[i]; r1 = relu[i+1];
|
||||
if( i+1 >= outCn )
|
||||
r1 = r0;
|
||||
}
|
||||
|
||||
int j = 0;
|
||||
#if CV_SIMD128
|
||||
v_float32x4 vr0 = v_setall_f32(r0), vr1 = v_setall_f32(r1), z = v_setzero_f32();
|
||||
|
||||
for( ; j <= bsz - 4; j += 4 )
|
||||
{
|
||||
const float* rptr = rowbuf0 + j*vsz_a;
|
||||
v_float32x4 s0, s1;
|
||||
|
||||
if( cn0 == 0 )
|
||||
{
|
||||
s0 = v_setall_f32(bias0);
|
||||
s1 = v_setall_f32(bias1);
|
||||
}
|
||||
else
|
||||
{
|
||||
s0 = v_load(outptr0 + j);
|
||||
s1 = v_load(outptr1 + j);
|
||||
}
|
||||
|
||||
v_float32x4 vs00 = v_setzero_f32(), vs01 = v_setzero_f32(),
|
||||
vs02 = v_setzero_f32(), vs03 = v_setzero_f32(),
|
||||
vs10 = v_setzero_f32(), vs11 = v_setzero_f32(),
|
||||
vs12 = v_setzero_f32(), vs13 = v_setzero_f32();
|
||||
for( k = 0; k < vsz; k += 4, rptr += 4 )
|
||||
{
|
||||
v_float32x4 w0 = v_load_aligned(wptr0 + k);
|
||||
v_float32x4 w1 = v_load_aligned(wptr1 + k);
|
||||
v_float32x4 r0 = v_load_aligned(rptr);
|
||||
v_float32x4 r1 = v_load_aligned(rptr + vsz_a);
|
||||
v_float32x4 r2 = v_load_aligned(rptr + vsz_a*2);
|
||||
v_float32x4 r3 = v_load_aligned(rptr + vsz_a*3);
|
||||
|
||||
vs00 = v_fma(w0, r0, vs00);
|
||||
vs01 = v_fma(w0, r1, vs01);
|
||||
vs02 = v_fma(w0, r2, vs02);
|
||||
vs03 = v_fma(w0, r3, vs03);
|
||||
|
||||
vs10 = v_fma(w1, r0, vs10);
|
||||
vs11 = v_fma(w1, r1, vs11);
|
||||
vs12 = v_fma(w1, r2, vs12);
|
||||
vs13 = v_fma(w1, r3, vs13);
|
||||
}
|
||||
s0 += v_reduce_sum4(vs00, vs01, vs02, vs03);
|
||||
s1 += v_reduce_sum4(vs10, vs11, vs12, vs13);
|
||||
if( relu )
|
||||
{
|
||||
s0 = v_select(s0 > z, s0, s0*vr0);
|
||||
s1 = v_select(s1 > z, s1, s1*vr1);
|
||||
}
|
||||
|
||||
v_store(outptr0 + j, s0);
|
||||
v_store(outptr1 + j, s1);
|
||||
}
|
||||
#endif
|
||||
for( ; j < bsz; j++ )
|
||||
{
|
||||
const float* rptr = rowbuf0 + j*vsz_a;
|
||||
float s00, s10;
|
||||
|
||||
if( cn0 == 0 )
|
||||
{
|
||||
s00 = bias0;
|
||||
s10 = bias1;
|
||||
}
|
||||
else
|
||||
{
|
||||
s00 = outptr0[j];
|
||||
s10 = outptr1[j];
|
||||
}
|
||||
|
||||
for( k = 0; k < vsz; k++ )
|
||||
{
|
||||
float r0 = rptr[k];
|
||||
s00 += wptr0[k]*r0;
|
||||
s10 += wptr1[k]*r0;
|
||||
}
|
||||
if( relu )
|
||||
{
|
||||
s00 = s00 > 0.f ? s00 : s00*r0;
|
||||
s10 = s10 > 0.f ? s10 : s10*r1;
|
||||
}
|
||||
|
||||
outptr0[j] = s00;
|
||||
outptr1[j] = s10;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if( activ_ )
|
||||
activ_->forwardSlice(data_out0 + stripeStart, data_out0 + stripeStart,
|
||||
(int)(stripeEnd - stripeStart),
|
||||
outPlaneSize, startOutCn, startOutCn + outCn);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef HAVE_OPENCL
|
||||
bool forward_ocl(InputArrayOfArrays inps, OutputArrayOfArrays outs, OutputArrayOfArrays internals)
|
||||
{
|
||||
@ -2096,40 +1294,27 @@ public:
|
||||
#endif
|
||||
{
|
||||
int nstripes = std::max(getNumThreads(), 1);
|
||||
int conv_dim = CONV_2D;
|
||||
if (inputs[0].dims == 3)
|
||||
conv_dim = CONV_1D;
|
||||
if (inputs[0].dims == 5)
|
||||
conv_dim = CONV_3D;
|
||||
|
||||
// Initialization of FastCovn2d, pack weight.
|
||||
if ((!fastConv2dImpl || variableWeight) && inputs[0].dims == 4)
|
||||
if (!fastConvImpl || variableWeight)
|
||||
{
|
||||
int K = outputs[0].size[1];
|
||||
int C = inputs[0].size[1];
|
||||
int Hk = kernel_size[kernel_size.size() - 2];
|
||||
int Wk = kernel_size.back();
|
||||
|
||||
// Winograd only works when input h and w >= 12.
|
||||
bool canUseWinograd = useWinograd && conv_dim == CONV_2D && inputs[0].size[2] >= 12 && inputs[0].size[3] >= 12;
|
||||
|
||||
CV_Assert(outputs[0].size[1] % ngroups == 0);
|
||||
int stride_h = strides[strides.size() - 2];
|
||||
int stride_w = strides.back();
|
||||
|
||||
int dilation_h = dilations[dilations.size() - 2];
|
||||
int dilation_w = dilations.back();
|
||||
|
||||
// Winograd only works well on input h and w >12.
|
||||
bool canUseWinograd = useWinograd && inputs[0].size[2] >= 12 && inputs[0].size[3] >= 12;
|
||||
|
||||
fastConv2dImpl = initFastConv2d(ngroups, K, C, Hk, Wk, stride_w, stride_h, dilation_w,
|
||||
dilation_h, pads_begin, pads_end, weightsMat, &biasvec[0], canUseWinograd);
|
||||
fastConvImpl = initFastConv(weightsMat, &biasvec[0], ngroups, K, C, kernel_size, strides,
|
||||
dilations, pads_begin, pads_end, conv_dim, canUseWinograd);
|
||||
}
|
||||
|
||||
if (fastConv2dImpl)
|
||||
{
|
||||
runFastConv2d(inputs[0], outputs[0], fastConv2dImpl, nstripes, activ, fusedAdd);
|
||||
return;
|
||||
}
|
||||
|
||||
//TODO: Add support of Conv1D and Conv3D to fastConv, and remove the old Conv branch.
|
||||
// Use only for Conv1D and Conv3D.
|
||||
CV_Assert(!fusedAdd);
|
||||
ParallelConv::run(inputs[0], outputs[0], weightsMat, biasvec, reluslope,
|
||||
kernel_size, strides, pads_begin, pads_end, dilations, activ.get(), ngroups, nstripes);
|
||||
runFastConv(inputs[0], outputs[0], fastConvImpl, nstripes, activ, reluslope, fusedAdd);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -11,381 +11,404 @@
|
||||
|
||||
#include "../../precomp.hpp"
|
||||
#include "fast_convolution.hpp"
|
||||
#include "../layers_common.hpp"
|
||||
|
||||
namespace cv { namespace dnn {
|
||||
|
||||
static void depthWiseBlock(const float *inptr, float *outptr, const float *weights, float biasval, int *ofstab, int *yxtab,
|
||||
float minval, float maxval, int Hi, int Wi, int H0, int W0, int ksize, int pad_top, int pad_left,
|
||||
int dilation_y, int stride_x, int stride_y, int inner_xleft, int inner_xright, int inner_ytop,
|
||||
int inner_ybottom, bool ifMinMaxAct, bool useSIMD, bool is3x3)
|
||||
static void depthWiseBlockConv2D(const float* wptr,
|
||||
int kernel_h, int kernel_w,
|
||||
int stride_h, int stride_w,
|
||||
int dilation_h, int dilation_w,
|
||||
int pad_t, int pad_l,
|
||||
const float* biasptr, const float* relu,
|
||||
const float* inptr_,
|
||||
int height, int width,
|
||||
float* outptr_,
|
||||
int out_d, int outH, int outW)
|
||||
{
|
||||
#if CV_SIMD128
|
||||
const int VEC_NLANES = 4;
|
||||
v_float32x4 vminval = v_setall_f32(minval), vmaxval = v_setall_f32(maxval);
|
||||
const float w00_ = wptr[0], w01_ = wptr[1], w02_ = wptr[2],
|
||||
w10 = wptr[3], w11 = wptr[4], w12 = wptr[5],
|
||||
w20_ = wptr[6], w21_ = wptr[7], w22_ = wptr[8];
|
||||
int outW1 = min(outW, (width - dilation_w*(kernel_w - 1) + pad_l)/stride_w);
|
||||
float relu_coeff = relu ? relu[out_d] : 1.f, bias = biasptr[out_d];
|
||||
|
||||
v_float32x4 w0 = v_setall_f32(
|
||||
0.f), w1 = w0, w2 = w0, w3 = w0, w4 = w0, w5 = w0, w6 = w0, w7 = w0, w8 = w0, vbias = w0;
|
||||
if (useSIMD)
|
||||
for (int out_i = 0; out_i < outH; out_i++)
|
||||
{
|
||||
vbias = v_setall_f32(biasval);
|
||||
if (is3x3)
|
||||
int in_i = out_i * stride_h - pad_t, out_j = 0;
|
||||
const float* imgptr0 = inptr_ + in_i*width;
|
||||
const float* imgptr1 = imgptr0 + dilation_h*width;
|
||||
const float* imgptr2 = imgptr0 + (dilation_h*2)*width;
|
||||
float out, w00 = w00_, w01 = w01_, w02 = w02_;
|
||||
float w20 = w20_, w21 = w21_, w22 = w22_;
|
||||
if (in_i < 0)
|
||||
{
|
||||
w0 = v_setall_f32(weights[0]);
|
||||
w1 = v_setall_f32(weights[1]);
|
||||
w2 = v_setall_f32(weights[2]);
|
||||
w3 = v_setall_f32(weights[3]);
|
||||
w4 = v_setall_f32(weights[4]);
|
||||
w5 = v_setall_f32(weights[5]);
|
||||
w6 = v_setall_f32(weights[6]);
|
||||
w7 = v_setall_f32(weights[7]);
|
||||
w8 = v_setall_f32(weights[8]);
|
||||
w00 = w01 = w02 = 0.f;
|
||||
imgptr0 = imgptr1;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
int dy0 = 1;
|
||||
for (int y0 = 0; y0 < H0; y0 += dy0, outptr += W0 * dy0)
|
||||
{
|
||||
#if CV_SIMD128
|
||||
dy0 = inner_ytop <= y0 && y0 + 3 < inner_ybottom && is3x3 && stride_y == 1 && dilation_y == 1
|
||||
? 3 : 1;
|
||||
#endif
|
||||
int x0 = 0, x1 = y0 >= inner_ytop && y0 < inner_ybottom ? inner_xleft : W0;
|
||||
int yi_ = y0 * stride_y - pad_top;
|
||||
|
||||
for (;;)
|
||||
else if (in_i + dilation_h*(kernel_h-1) >= height)
|
||||
{
|
||||
float s_0, s_1, s_2;
|
||||
if (dy0 == 3)
|
||||
{
|
||||
for (; x0 < x1; x0++)
|
||||
{
|
||||
int xi_ = x0 * stride_x - pad_left;
|
||||
s_0 = s_1 = s_2 = biasval;
|
||||
for (int k = 0; k < ksize; k++)
|
||||
{
|
||||
int dy = yxtab[k * 2];
|
||||
int yi = yi_ + dy;
|
||||
int xi = xi_ + yxtab[k * 2 + 1];
|
||||
float w = weights[k];
|
||||
w20 = w21 = w22 = 0.f;
|
||||
imgptr2 = imgptr1;
|
||||
}
|
||||
|
||||
float* outptr = outptr_ + out_i*outW;
|
||||
if (pad_l > 0)
|
||||
{
|
||||
out = imgptr0[0]*w01 + imgptr0[dilation_w]*w02 +
|
||||
imgptr1[0]*w11 + imgptr1[dilation_w]*w12 +
|
||||
imgptr2[0]*w21 + imgptr2[dilation_w]*w22 + bias;
|
||||
if (relu)
|
||||
out = out > 0.f ? out : out*relu_coeff;
|
||||
outptr[0] = out;
|
||||
out_j = 1;
|
||||
}
|
||||
|
||||
if ((unsigned) xi < (unsigned) Wi)
|
||||
{
|
||||
s_0 += inptr[yi * Wi + xi] * w;
|
||||
s_1 += inptr[(yi + 1) * Wi + xi] * w;
|
||||
s_2 += inptr[(yi + 2) * Wi + xi] * w;
|
||||
}
|
||||
}
|
||||
s_0 = std::min(std::max(s_0, minval), maxval);
|
||||
s_1 = std::min(std::max(s_1, minval), maxval);
|
||||
s_2 = std::min(std::max(s_2, minval), maxval);
|
||||
outptr[x0] = s_0;
|
||||
outptr[x0 + W0] = s_1;
|
||||
outptr[x0 + W0 * 2] = s_2;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (; x0 < x1; x0++)
|
||||
{
|
||||
int xi_ = x0 * stride_x - pad_left;
|
||||
s_0 = biasval;
|
||||
for (int k = 0; k < ksize; k++) {
|
||||
int dy = yxtab[k * 2];
|
||||
int yi = yi_ + dy;
|
||||
int xi = xi_ + yxtab[k * 2 + 1];
|
||||
float w = weights[k];
|
||||
if (((unsigned) yi < (unsigned) Hi) & ((unsigned) xi < (unsigned) Wi))
|
||||
s_0 += inptr[yi * Wi + xi] * w;
|
||||
}
|
||||
s_0 = std::min(std::max(s_0, minval), maxval);
|
||||
outptr[x0] = s_0;
|
||||
}
|
||||
}
|
||||
if (x0 == W0)
|
||||
break;
|
||||
x1 = inner_xright;
|
||||
#if CV_SIMD128
|
||||
if (useSIMD)
|
||||
const int VEC_NLANES = 4;
|
||||
v_float32x4 vw00 = v_setall_f32(w00);
|
||||
v_float32x4 vw01 = v_setall_f32(w01);
|
||||
v_float32x4 vw02 = v_setall_f32(w02);
|
||||
v_float32x4 vw10 = v_setall_f32(w10);
|
||||
v_float32x4 vw11 = v_setall_f32(w11);
|
||||
v_float32x4 vw12 = v_setall_f32(w12);
|
||||
v_float32x4 vw20 = v_setall_f32(w20);
|
||||
v_float32x4 vw21 = v_setall_f32(w21);
|
||||
v_float32x4 vw22 = v_setall_f32(w22);
|
||||
v_float32x4 z = v_setzero_f32();
|
||||
v_float32x4 vbias = v_setall_f32(bias);
|
||||
v_float32x4 vrc = v_setall_f32(relu_coeff);
|
||||
|
||||
if (stride_w == 1 || (stride_w == 2 && dilation_w == 1))
|
||||
{
|
||||
if( stride_w == 1 )
|
||||
{
|
||||
if (is3x3)
|
||||
for( ; out_j < outW1; out_j += VEC_NLANES )
|
||||
{
|
||||
if (dy0 == 3)
|
||||
if (out_j + VEC_NLANES > outW1)
|
||||
{
|
||||
for (; x0 <= x1 - VEC_NLANES; x0 += VEC_NLANES)
|
||||
{
|
||||
int xi_ = x0 * stride_x - pad_left;
|
||||
const float *inptr_xi = inptr + Wi * yi_ + xi_;
|
||||
|
||||
v_float32x4 s0, s1, s2;
|
||||
v_float32x4 x00 = v_load(inptr_xi);
|
||||
v_float32x4 x01 = v_load(inptr_xi + 1);
|
||||
v_float32x4 x02 = v_load(inptr_xi + 2);
|
||||
|
||||
v_float32x4 x10 = v_load(inptr_xi + Wi);
|
||||
v_float32x4 x11 = v_load(inptr_xi + Wi + 1);
|
||||
v_float32x4 x12 = v_load(inptr_xi + Wi + 2);
|
||||
|
||||
v_float32x4 x20 = v_load(inptr_xi + Wi * 2);
|
||||
v_float32x4 x21 = v_load(inptr_xi + Wi * 2 + 1);
|
||||
v_float32x4 x22 = v_load(inptr_xi + Wi * 2 + 2);
|
||||
|
||||
v_float32x4 x30 = v_load(inptr_xi + Wi * 3);
|
||||
v_float32x4 x31 = v_load(inptr_xi + Wi * 3 + 1);
|
||||
v_float32x4 x32 = v_load(inptr_xi + Wi * 3 + 2);
|
||||
|
||||
v_float32x4 x40 = v_load(inptr_xi + Wi * 4);
|
||||
v_float32x4 x41 = v_load(inptr_xi + Wi * 4 + 1);
|
||||
v_float32x4 x42 = v_load(inptr_xi + Wi * 4 + 2);
|
||||
|
||||
s0 = v_fma(x00, w0, vbias);
|
||||
s1 = v_fma(x10, w0, vbias);
|
||||
s2 = v_fma(x20, w0, vbias);
|
||||
|
||||
s0 = v_fma(x01, w1, s0);
|
||||
s1 = v_fma(x11, w1, s1);
|
||||
s2 = v_fma(x21, w1, s2);
|
||||
|
||||
s0 = v_fma(x02, w2, s0);
|
||||
s1 = v_fma(x12, w2, s1);
|
||||
s2 = v_fma(x22, w2, s2);
|
||||
|
||||
s0 = v_fma(x10, w3, s0);
|
||||
s1 = v_fma(x20, w3, s1);
|
||||
s2 = v_fma(x30, w3, s2);
|
||||
|
||||
s0 = v_fma(x11, w4, s0);
|
||||
s1 = v_fma(x21, w4, s1);
|
||||
s2 = v_fma(x31, w4, s2);
|
||||
|
||||
s0 = v_fma(x12, w5, s0);
|
||||
s1 = v_fma(x22, w5, s1);
|
||||
s2 = v_fma(x32, w5, s2);
|
||||
|
||||
s0 = v_fma(x20, w6, s0);
|
||||
s1 = v_fma(x30, w6, s1);
|
||||
s2 = v_fma(x40, w6, s2);
|
||||
|
||||
s0 = v_fma(x21, w7, s0);
|
||||
s1 = v_fma(x31, w7, s1);
|
||||
s2 = v_fma(x41, w7, s2);
|
||||
|
||||
s0 = v_fma(x22, w8, s0);
|
||||
s1 = v_fma(x32, w8, s1);
|
||||
s2 = v_fma(x42, w8, s2);
|
||||
|
||||
if (ifMinMaxAct)
|
||||
{
|
||||
s0 = v_min(v_max(s0, vminval), vmaxval);
|
||||
s1 = v_min(v_max(s1, vminval), vmaxval);
|
||||
s2 = v_min(v_max(s2, vminval), vmaxval);
|
||||
}
|
||||
|
||||
v_store(outptr + x0, s0);
|
||||
v_store(outptr + W0 + x0, s1);
|
||||
v_store(outptr + W0 * 2 + x0, s2);
|
||||
}
|
||||
if (out_j <= pad_l || outW1 - VEC_NLANES < 0)
|
||||
break;
|
||||
out_j = outW1 - VEC_NLANES;
|
||||
}
|
||||
else
|
||||
{
|
||||
for (; x0 <= x1 - VEC_NLANES; x0 += VEC_NLANES)
|
||||
{
|
||||
int xi_ = x0 * stride_x - pad_left;
|
||||
const float *inptr_xi = inptr + Wi * yi_ + xi_;
|
||||
v_float32x4 s0 = v_fma(v_load(inptr_xi + ofstab[0]), w0, vbias);
|
||||
v_float32x4 s1 = v_load(inptr_xi + ofstab[1]) * w1;
|
||||
v_float32x4 s2 = v_load(inptr_xi + ofstab[2]) * w2;
|
||||
int in_j = out_j * stride_w - pad_l;
|
||||
v_float32x4 v00 = v_load(imgptr0 + in_j),
|
||||
v01 = v_load(imgptr0 + in_j + dilation_w),
|
||||
v02 = v_load(imgptr0 + in_j + dilation_w*2),
|
||||
v10 = v_load(imgptr1 + in_j),
|
||||
v11 = v_load(imgptr1 + in_j + dilation_w),
|
||||
v12 = v_load(imgptr1 + in_j + dilation_w*2),
|
||||
v20 = v_load(imgptr2 + in_j),
|
||||
v21 = v_load(imgptr2 + in_j + dilation_w),
|
||||
v22 = v_load(imgptr2 + in_j + dilation_w*2);
|
||||
|
||||
s0 = v_fma(v_load(inptr_xi + ofstab[3]), w3, s0);
|
||||
s1 = v_fma(v_load(inptr_xi + ofstab[4]), w4, s1);
|
||||
s2 = v_fma(v_load(inptr_xi + ofstab[5]), w5, s2);
|
||||
|
||||
s0 = v_fma(v_load(inptr_xi + ofstab[6]), w6, s0);
|
||||
s1 = v_fma(v_load(inptr_xi + ofstab[7]), w7, s1);
|
||||
s2 = v_fma(v_load(inptr_xi + ofstab[8]), w8, s2);
|
||||
|
||||
s0 = s0 + s1 + s2;
|
||||
if (ifMinMaxAct)
|
||||
s0 = v_min(v_max(s0, vminval), vmaxval);
|
||||
v_store(outptr + x0, s0);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (; x0 <= x1 - VEC_NLANES; x0 += VEC_NLANES)
|
||||
{
|
||||
int xi_ = x0 * stride_x - pad_left, k = 0;
|
||||
const float *inptr_xi = inptr + Wi * yi_ + xi_;
|
||||
v_float32x4 s0 = vbias;
|
||||
for (; k <= ksize - 4; k += 4)
|
||||
{
|
||||
v_float32x4 v0 = v_load(inptr_xi + ofstab[k]);
|
||||
v_float32x4 v1 = v_load(inptr_xi + ofstab[k + 1]);
|
||||
v_float32x4 v2 = v_load(inptr_xi + ofstab[k + 2]);
|
||||
v_float32x4 v3 = v_load(inptr_xi + ofstab[k + 3]);
|
||||
|
||||
v_float32x4 ww0 = v_setall_f32(weights[k]);
|
||||
v_float32x4 ww1 = v_setall_f32(weights[k+1]);
|
||||
v_float32x4 ww2 = v_setall_f32(weights[k+2]);
|
||||
v_float32x4 ww3 = v_setall_f32(weights[k+3]);
|
||||
|
||||
s0 = v_fma(v0, ww0, s0);
|
||||
s0 = v_fma(v1, ww1, s0);
|
||||
s0 = v_fma(v2, ww2, s0);
|
||||
s0 = v_fma(v3, ww3, s0);
|
||||
}
|
||||
for (; k < ksize; k++)
|
||||
s0 = v_fma(v_load(inptr_xi + ofstab[k]),
|
||||
v_setall_f32(weights[k]), s0);
|
||||
if (ifMinMaxAct)
|
||||
s0 = v_min(v_max(s0, vminval), vmaxval);
|
||||
v_store(outptr + x0, s0);
|
||||
}
|
||||
v_float32x4 vout = v00*vw00 + v01*vw01 + v02*vw02 +
|
||||
v10*vw10 + v11*vw11 + v12*vw12 +
|
||||
v20*vw20 + v21*vw21 + v22*vw22 + vbias;
|
||||
if (relu)
|
||||
vout = v_select(vout > z, vout, vout*vrc);
|
||||
v_store(outptr + out_j, vout);
|
||||
}
|
||||
}
|
||||
else // (stride_w == 2 && dilation_w == 1)
|
||||
{
|
||||
for( ; out_j < outW1; out_j += VEC_NLANES )
|
||||
{
|
||||
if (out_j + VEC_NLANES > outW1 && out_j > pad_l)
|
||||
{
|
||||
if (outW1 - VEC_NLANES < 0)
|
||||
break;
|
||||
out_j = outW1 - VEC_NLANES;
|
||||
}
|
||||
|
||||
int in_j = out_j * stride_w - pad_l;
|
||||
|
||||
v_float32x4 v00, v01, v02, v10, v11, v12, v20, v21, v22, unused;
|
||||
v_load_deinterleave(imgptr0 + in_j, v00, v01);
|
||||
v_load_deinterleave(imgptr0 + in_j + 2, v02, unused);
|
||||
v_load_deinterleave(imgptr1 + in_j, v10, v11);
|
||||
v_load_deinterleave(imgptr1 + in_j + 2, v12, unused);
|
||||
v_load_deinterleave(imgptr2 + in_j, v20, v21);
|
||||
v_load_deinterleave(imgptr2 + in_j + 2, v22, unused);
|
||||
|
||||
v_float32x4 vout = v00 * vw00 + v01 * vw01 + v02 * vw02 +
|
||||
v10 * vw10 + v11 * vw11 + v12 * vw12 +
|
||||
v20 * vw20 + v21 * vw21 + v22 * vw22 + vbias;
|
||||
|
||||
if (relu)
|
||||
vout = v_select(vout > z, vout, vout*vrc);
|
||||
v_store(outptr + out_j, vout);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
if (dy0 == 3)
|
||||
{
|
||||
for (; x0 < x1; x0++)
|
||||
{
|
||||
int xi_ = x0 * stride_x - pad_left;
|
||||
const float *inptr_xi = inptr + W0 * yi_ + xi_;
|
||||
s_0 = s_1 = s_2 = biasval;
|
||||
for (int k = 0; k < ksize; k++)
|
||||
{
|
||||
int inp_ofs = ofstab[k];
|
||||
float w = weights[k];
|
||||
s_0 += inptr_xi[inp_ofs] * w;
|
||||
s_1 += inptr_xi[inp_ofs + Wi] * w;
|
||||
s_2 += inptr_xi[inp_ofs + Wi * 2] * w;
|
||||
}
|
||||
if (ifMinMaxAct)
|
||||
{
|
||||
s_0 = std::min(std::max(s_0, minval), maxval);
|
||||
s_1 = std::min(std::max(s_1, minval), maxval);
|
||||
s_2 = std::min(std::max(s_2, minval), maxval);
|
||||
}
|
||||
|
||||
outptr[x0] = s_0;
|
||||
outptr[x0 + W0] = s_1;
|
||||
outptr[x0 + W0 * 2] = s_2;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (; x0 < x1; x0++)
|
||||
{
|
||||
int xi_ = x0 * stride_x - pad_left;
|
||||
const float *inptr_xi = inptr + Wi * yi_ + xi_;
|
||||
s_0 = biasval;
|
||||
for (int k = 0; k < ksize; k++)
|
||||
{
|
||||
s_0 += inptr_xi[ofstab[k]] * weights[k];
|
||||
}
|
||||
for (; out_j < outW1; out_j++)
|
||||
{
|
||||
int in_j = out_j * stride_w - pad_l;
|
||||
out = imgptr0[in_j]*w00 + imgptr0[in_j + dilation_w]*w01 + imgptr0[in_j + dilation_w*2]*w02 +
|
||||
imgptr1[in_j]*w10 + imgptr1[in_j + dilation_w]*w11 + imgptr1[in_j + dilation_w*2]*w12 +
|
||||
imgptr2[in_j]*w20 + imgptr2[in_j + dilation_w]*w21 + imgptr2[in_j + dilation_w*2]*w22 + bias;
|
||||
if (relu)
|
||||
out = out > 0.f ? out : out*relu_coeff;
|
||||
outptr[out_j] = out;
|
||||
}
|
||||
|
||||
if (ifMinMaxAct)
|
||||
s_0 = std::min(std::max(s_0, minval), maxval);
|
||||
outptr[x0] = s_0;
|
||||
}
|
||||
for (; out_j < outW; out_j++ )
|
||||
{
|
||||
int in_j0 = out_j * stride_w - pad_l, in_j1 = in_j0 + dilation_w, in_j2 = in_j0 + dilation_w*2;
|
||||
float s0 = 1.f, s1 = 1.f, s2 = 1.f;
|
||||
if (in_j0 >= width)
|
||||
{
|
||||
in_j0 = 0;
|
||||
s0 = 0.f;
|
||||
}
|
||||
x1 = W0;
|
||||
if (in_j1 >= width)
|
||||
{
|
||||
in_j1 = 0;
|
||||
s1 = 0.f;
|
||||
}
|
||||
if (in_j2 >= width)
|
||||
{
|
||||
in_j2 = 0;
|
||||
s2 = 0.f;
|
||||
}
|
||||
out = imgptr0[in_j0]*w00*s0 + imgptr0[in_j1]*w01*s1 + imgptr0[in_j2]*w02*s2 +
|
||||
imgptr1[in_j0]*w10*s0 + imgptr1[in_j1]*w11*s1 + imgptr1[in_j2]*w12*s2 +
|
||||
imgptr2[in_j0]*w20*s0 + imgptr2[in_j1]*w21*s1 + imgptr2[in_j2]*w22*s2 + bias;
|
||||
if (relu)
|
||||
out = out > 0.f ? out : out*relu_coeff;
|
||||
outptr[out_j] = out;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void runDepthwise(InputArray _input, OutputArray _output, const Ptr<FastConv2d>& conv, float minval, float maxval, ActivationLayer* activ, bool ifMinMaxAct) {
|
||||
static void depthWiseBlockConv1D(const float* wptr,
|
||||
int kernel_w, int stride_w, int dilation_w, int pad_l,
|
||||
const float* biasptr, const float* relu,
|
||||
const float* inptr_, int width,
|
||||
float* outptr_,
|
||||
int out_d, int outW)
|
||||
{
|
||||
const float w00_ = wptr[0], w01_ = wptr[1], w02_ = wptr[2];
|
||||
int outW1 = min(outW, (width - dilation_w * (kernel_w - 1) + pad_l)/stride_w);
|
||||
float relu_coeff = relu ? relu[out_d] : 1.f, bias = biasptr[out_d];
|
||||
|
||||
int out_j = 0;
|
||||
const float* imgptr0 = inptr_;
|
||||
float out, w00 = w00_, w01 = w01_, w02 = w02_;
|
||||
float* outptr = outptr_;
|
||||
|
||||
if (pad_l > 0)
|
||||
{
|
||||
out = imgptr0[0]*w01 + imgptr0[dilation_w]*w02 + bias;
|
||||
|
||||
if (relu)
|
||||
out = out > 0.f ? out : out*relu_coeff;
|
||||
outptr[0] = out;
|
||||
out_j = 1;
|
||||
}
|
||||
|
||||
#if CV_SIMD128
|
||||
const int VEC_NLANES = 4;
|
||||
v_float32x4 vw00 = v_setall_f32(w00);
|
||||
v_float32x4 vw01 = v_setall_f32(w01);
|
||||
v_float32x4 vw02 = v_setall_f32(w02);
|
||||
v_float32x4 z = v_setzero_f32();
|
||||
v_float32x4 vbias = v_setall_f32(bias);
|
||||
v_float32x4 vrc = v_setall_f32(relu_coeff);
|
||||
|
||||
if (stride_w == 1 || (stride_w == 2 && dilation_w == 1))
|
||||
{
|
||||
if( stride_w == 1 )
|
||||
{
|
||||
for( ; out_j < outW1; out_j += VEC_NLANES )
|
||||
{
|
||||
if (out_j + VEC_NLANES > outW1)
|
||||
{
|
||||
if (out_j <= pad_l || outW1 - VEC_NLANES < 0)
|
||||
break;
|
||||
out_j = outW1 - VEC_NLANES;
|
||||
}
|
||||
int in_j = out_j * stride_w - pad_l;
|
||||
v_float32x4 v00 = v_load(imgptr0 + in_j),
|
||||
v01 = v_load(imgptr0 + in_j + dilation_w),
|
||||
v02 = v_load(imgptr0 + in_j + dilation_w*2);
|
||||
|
||||
v_float32x4 vout = v00*vw00 + v01*vw01 + v02*vw02 + vbias;
|
||||
if (relu)
|
||||
vout = v_select(vout > z, vout, vout*vrc);
|
||||
v_store(outptr + out_j, vout);
|
||||
}
|
||||
}
|
||||
else // (stride_w == 2 && dilation_w == 1)
|
||||
{
|
||||
for( ; out_j < outW1; out_j += VEC_NLANES )
|
||||
{
|
||||
if (out_j + VEC_NLANES > outW1)
|
||||
{
|
||||
if (out_j <= pad_l || outW1 - VEC_NLANES < 0)
|
||||
break;
|
||||
out_j = outW1 - VEC_NLANES;
|
||||
}
|
||||
int in_j = out_j * stride_w - pad_l;
|
||||
|
||||
v_float32x4 v00, v01, v02, unused;
|
||||
v_load_deinterleave(imgptr0 + in_j, v00, v01);
|
||||
v_load_deinterleave(imgptr0 + in_j + 2, v02, unused);
|
||||
|
||||
v_float32x4 vout = v00 * vw00 + v01 * vw01 + v02 * vw02 + vbias;
|
||||
|
||||
if (relu)
|
||||
vout = v_select(vout > z, vout, vout*vrc);
|
||||
v_store(outptr + out_j, vout);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
for (; out_j < outW1; out_j++)
|
||||
{
|
||||
int in_j = out_j * stride_w - pad_l;
|
||||
out = imgptr0[in_j]*w00 + imgptr0[in_j + dilation_w]*w01 + imgptr0[in_j + dilation_w*2]*w02 + bias;
|
||||
if (relu)
|
||||
out = out > 0.f ? out : out*relu_coeff;
|
||||
outptr[out_j] = out;
|
||||
}
|
||||
|
||||
for (; out_j < outW; out_j++ )
|
||||
{
|
||||
int in_j0 = out_j * stride_w - pad_l, in_j1 = in_j0 + dilation_w, in_j2 = in_j0 + dilation_w*2;
|
||||
float s0 = 1.f, s1 = 1.f, s2 = 1.f;
|
||||
if (in_j0 >= width)
|
||||
{
|
||||
in_j0 = 0;
|
||||
s0 = 0.f;
|
||||
}
|
||||
if (in_j1 >= width)
|
||||
{
|
||||
in_j1 = 0;
|
||||
s1 = 0.f;
|
||||
}
|
||||
if (in_j2 >= width)
|
||||
{
|
||||
in_j2 = 0;
|
||||
s2 = 0.f;
|
||||
}
|
||||
out = imgptr0[in_j0]*w00*s0 + imgptr0[in_j1]*w01*s1 + imgptr0[in_j2]*w02*s2 + bias;
|
||||
if (relu)
|
||||
out = out > 0.f ? out : out*relu_coeff;
|
||||
outptr[out_j] = out;
|
||||
}
|
||||
}
|
||||
|
||||
void runDepthwise(InputArray _input, OutputArray _output, const Ptr<FastConv>& conv, ActivationLayer* activ_,
|
||||
const std::vector<float>& reluslope)
|
||||
{
|
||||
Mat input = _input.getMat();
|
||||
Mat output = _output.getMat();
|
||||
MatShape inputShape = shape(input);
|
||||
MatShape outputShape = shape(output);
|
||||
CV_Assert(inputShape.size() == 4 && outputShape.size() == 4);
|
||||
|
||||
int N = inputShape[0], C = inputShape[1], Hi = inputShape[2], Wi = inputShape[3]; // [N, C, H, W]
|
||||
CV_Assert(inputShape.size() == 3 || inputShape.size() == 4);
|
||||
CV_Assert(inputShape.size() == outputShape.size());
|
||||
|
||||
int conv_dim = conv->conv_dim;
|
||||
CV_Assert((conv_dim == CONV_2D || conv_dim == CONV_1D) &&
|
||||
"DNN: Currently we do not support depth-wise for Convolution 3D!");
|
||||
|
||||
ActivationLayer* activ = reluslope.empty() ? activ_ : nullptr;
|
||||
int N = inputShape[0], C = inputShape[1];
|
||||
|
||||
int Hi = conv_dim == CONV_1D ? 1 : inputShape[inputShape.size() - 2];
|
||||
int Wi = inputShape[inputShape.size() - 1];
|
||||
|
||||
int K = conv->K, Hk = conv->Hk, Wk = conv->Wk;
|
||||
int H0 = outputShape[2], W0 = outputShape[3], ngroups = conv->ngroups;
|
||||
|
||||
int H0 = conv_dim == CONV_1D ? 1 : outputShape[outputShape.size() - 2];
|
||||
int W0 = outputShape[outputShape.size() - 1];
|
||||
int ngroups = conv->ngroups;
|
||||
|
||||
const size_t inp_planesize = (size_t) Hi * Wi;
|
||||
const size_t out_planesize = (size_t) H0 * W0;
|
||||
|
||||
CV_Assert(ngroups > 1 && ngroups == K && ngroups == C);
|
||||
|
||||
int stride_y = conv->stride_y, stride_x = conv->stride_x;
|
||||
int dilation_y = conv->dilation_y, dilation_x = conv->dilation_x;
|
||||
int stride_h = conv->stride_h, stride_w = conv->stride_w;
|
||||
int dilation_h = conv->dilation_h, dilation_w = conv->dilation_w;
|
||||
|
||||
int pad_top = conv->pad_top, pad_bottom = conv->pad_bottom;
|
||||
int pad_left = conv->pad_left, pad_right = conv->pad_right;
|
||||
|
||||
int VEC_NLANES = 4;
|
||||
#if CV_TRY_AVX2
|
||||
if (conv->useAVX2)
|
||||
VEC_NLANES = 8;
|
||||
#endif
|
||||
int ksize = Hk * Wk, padded_ksize = ((ksize + VEC_NLANES - 1) / VEC_NLANES) * VEC_NLANES;
|
||||
int ksize = Hk * Wk;
|
||||
|
||||
const int VEC_NLANES = 32;
|
||||
int padded_ksize = ((ksize + VEC_NLANES-1) / VEC_NLANES) * VEC_NLANES;
|
||||
|
||||
const float *inp = input.ptr<float>();
|
||||
float *out = output.ptr<float>();
|
||||
|
||||
std::vector<int> ofstab_(3 * padded_ksize, 0);
|
||||
#if CV_TRY_AVX2 || CV_TRY_AVX || CV_TRY_RVV
|
||||
// TODO: remove the following limitation, need change code in layers_common.simd.hpp.
|
||||
bool canRunOpt = Wi >= 16 + dilation_w*(Wk - 1);
|
||||
#endif
|
||||
std::vector<int> ofstab_(3 * ksize, 0);
|
||||
int *ofstab = ofstab_.data();
|
||||
int *yxtab = ofstab + padded_ksize;
|
||||
int *yxtab = ofstab + ksize;
|
||||
|
||||
for (int k = 0; k < padded_ksize; k++)
|
||||
for (int k = 0; k < ksize; k++)
|
||||
{
|
||||
int y = k < ksize ? k / Wk : 0;
|
||||
int x = k < ksize ? k % Wk : 0;
|
||||
int dy = y * dilation_y, dx = x * dilation_x;
|
||||
int dy = y * dilation_h, dx = x * dilation_w;
|
||||
yxtab[k * 2] = dy;
|
||||
yxtab[k * 2 + 1] = dx;
|
||||
ofstab[k] = dy * Wi + dx;
|
||||
}
|
||||
|
||||
const float *weights0 = conv->weightsBufPtr, *bias = conv->biasBuf.data();
|
||||
int inner_ytop = (pad_bottom + stride_y - 1) / stride_y, inner_ybottom = 3;
|
||||
int inner_xleft = (pad_left + stride_x - 1) / stride_x, inner_xright = 4;
|
||||
|
||||
const float* relu = reluslope.data();
|
||||
CV_Assert(ksize > 1 || (pad_left == 0 && pad_right == 0 && pad_top == 0 && pad_bottom == 0));
|
||||
|
||||
inner_xright = (Wi - (Wk - 1) * dilation_x + pad_left) / stride_x;
|
||||
inner_xright += inner_xright * stride_x - pad_left + (Wk - 1) * dilation_x < Wi;
|
||||
inner_ybottom = (Hi - (Hk - 1) * dilation_y + pad_top) / stride_y;
|
||||
inner_ybottom += inner_ybottom * stride_y - pad_top + (Hk - 1) * dilation_y < Hi;
|
||||
|
||||
if (inner_xleft >= inner_xright || inner_ytop >= inner_ybottom)
|
||||
{
|
||||
inner_xleft = W0;
|
||||
inner_ytop = H0;
|
||||
}
|
||||
|
||||
inner_ybottom = inner_ybottom < H0 ? inner_ybottom : H0;
|
||||
|
||||
bool useSIMD = stride_x == 1 && inner_xleft < W0;
|
||||
bool is3x3 = Hk == 3 && Wk == 3;
|
||||
|
||||
parallel_for_(Range(0, N * C), [&](const Range &r0) {
|
||||
for (int nc = r0.start; nc < r0.end; nc++)
|
||||
for (int nc = r0.start; nc < r0.end; nc++)
|
||||
{
|
||||
int c = nc % C;
|
||||
const float *inptr0 = inp + inp_planesize * nc;
|
||||
float *outptr0 = out + out_planesize * nc;
|
||||
|
||||
const float *weights = weights0 + c * padded_ksize;
|
||||
|
||||
if (conv_dim == CONV_2D)
|
||||
{
|
||||
int c = nc % C;
|
||||
const float *inptr = inp + inp_planesize * nc;
|
||||
float *outptr0 = out + out_planesize * nc;
|
||||
|
||||
float biasval = bias[c];
|
||||
const float *weights = weights0 + c * padded_ksize;
|
||||
|
||||
#if CV_TRY_AVX2
|
||||
if (conv->useAVX2)
|
||||
opt_AVX2::depthWiseBlock_AVX2(inptr, outptr0, weights, biasval, ofstab, yxtab, minval, maxval, Hi, Wi, H0, W0, ksize,
|
||||
pad_top, pad_left, dilation_y, stride_x, stride_y, inner_xleft, inner_xright, inner_ytop,
|
||||
inner_ybottom, ifMinMaxAct, useSIMD, is3x3);
|
||||
if(canRunOpt && conv->useAVX2)
|
||||
opt_AVX2::fastDepthwiseConv(weights, Hk, Wk, stride_h, stride_w, dilation_h, dilation_w,
|
||||
pad_top, pad_left, bias, relu, inptr0, Hi, Wi, outptr0, c, H0, W0);
|
||||
else
|
||||
#endif
|
||||
depthWiseBlock(inptr, outptr0, weights, biasval, ofstab, yxtab, minval, maxval, Hi, Wi, H0, W0, ksize,
|
||||
pad_top, pad_left, dilation_y, stride_x, stride_y, inner_xleft, inner_xright, inner_ytop,
|
||||
inner_ybottom, ifMinMaxAct, useSIMD, is3x3);
|
||||
|
||||
if (activ)
|
||||
activ->forwardSlice(outptr0, outptr0, (int) out_planesize, out_planesize, c, c+1);
|
||||
#if CV_TRY_AVX
|
||||
if(canRunOpt && conv->useAVX)
|
||||
opt_AVX::fastDepthwiseConv(weights, Hk, Wk, stride_h, stride_w, dilation_h, dilation_w,
|
||||
pad_top, pad_left, bias, relu, inptr0, Hi, Wi, outptr0, c, H0, W0);
|
||||
else
|
||||
#endif
|
||||
#if CV_TRY_RVV
|
||||
if(canRunOpt && conv->useRVV)
|
||||
opt_RVV::fastDepthwiseConv(weights, Hk, Wk, stride_h, stride_w, dilation_h, dilation_w,
|
||||
pad_top, pad_left, bias, relu, inptr0, Hi, Wi, outptr0, c, H0, W0);
|
||||
else
|
||||
#endif
|
||||
depthWiseBlockConv2D(weights, Hk, Wk, stride_h, stride_w, dilation_h, dilation_w,
|
||||
pad_top, pad_left, bias, relu, inptr0, Hi, Wi, outptr0, c, H0, W0);
|
||||
}
|
||||
});
|
||||
else // conv_dim == CONV_1D, spatial branch for depth-wise Conv1D.
|
||||
{
|
||||
depthWiseBlockConv1D(weights, Wk, stride_w, dilation_w, pad_left, bias, relu, inptr0, Wi, outptr0, c, W0);
|
||||
}
|
||||
|
||||
if (activ)
|
||||
activ->forwardSlice(outptr0, outptr0, (int) out_planesize, out_planesize, c, c+1);
|
||||
}});
|
||||
}
|
||||
|
||||
}} // namespace cv::dnn
|
||||
}} // namespace cv::dnn
|
||||
|
@ -6,9 +6,52 @@
|
||||
#include "fast_convolution.hpp"
|
||||
|
||||
namespace cv {
|
||||
namespace dnn {
|
||||
namespace opt_AVX2
|
||||
{
|
||||
#if CV_TRY_AVX2
|
||||
void convBlockMR1(int np, const float* a, const float* b, float *c, const float bias, bool init_c,
|
||||
const float minval, const float maxval, bool ifMinMaxAct)
|
||||
{
|
||||
#if CONV_NR == 24
|
||||
__m256 c0 = _mm256_set1_ps(bias), c1 = c0, c2 = c0;
|
||||
|
||||
for (int p = 0; p < np; p++, a++, b += CONV_NR)
|
||||
{
|
||||
__m256 a0 = _mm256_set1_ps(a[0]);
|
||||
__m256 b0 = _mm256_loadu_ps(b), b1 = _mm256_loadu_ps(b + 8), b2 = _mm256_loadu_ps(b + 16);
|
||||
|
||||
c0 = _mm256_fmadd_ps(b0, a0, c0);
|
||||
c1 = _mm256_fmadd_ps(b1, a0, c1);
|
||||
c2 = _mm256_fmadd_ps(b2, a0, c2);
|
||||
}
|
||||
|
||||
if (init_c)
|
||||
{
|
||||
c0 = _mm256_add_ps(_mm256_loadu_ps(c), c0);
|
||||
c1 = _mm256_add_ps(_mm256_loadu_ps(c + 8), c1);
|
||||
c2 = _mm256_add_ps(_mm256_loadu_ps(c + 16), c2);
|
||||
}
|
||||
|
||||
if (ifMinMaxAct)
|
||||
{
|
||||
__m256 vmax = _mm256_set1_ps(maxval);
|
||||
__m256 vmin = _mm256_set1_ps(minval);
|
||||
|
||||
c0 = _mm256_min_ps(_mm256_max_ps(c0, vmin), vmax);
|
||||
c1 = _mm256_min_ps(_mm256_max_ps(c1, vmin), vmax);
|
||||
c2 = _mm256_min_ps(_mm256_max_ps(c2, vmin), vmax);
|
||||
}
|
||||
|
||||
_mm256_storeu_ps(c, c0);
|
||||
_mm256_storeu_ps(c + 8, c1);
|
||||
_mm256_storeu_ps(c + 16, c2);
|
||||
_mm256_zeroupper();
|
||||
#else
|
||||
#error "unsupported CONV_NR in convBlockMR1."
|
||||
#endif
|
||||
}
|
||||
|
||||
void convBlock_AVX2(int np, const float* a, const float* b, float* c, int ldc, bool init_c)
|
||||
{
|
||||
#if CONV_MR == 4 && CONV_NR == 24
|
||||
@ -73,291 +116,6 @@ void convBlock_AVX2(int np, const float* a, const float* b, float* c, int ldc, b
|
||||
#endif
|
||||
}
|
||||
|
||||
void depthWiseBlock_AVX2(const float *inptr, float *outptr, const float *weights, float biasval, int *ofstab, int *yxtab,
|
||||
float minval, float maxval, int Hi, int Wi, int H0, int W0, int ksize, int pad_top, int pad_left,
|
||||
int dilation_y, int stride_x, int stride_y, int inner_xleft, int inner_xright, int inner_ytop,
|
||||
int inner_ybottom, bool ifMinMaxAct, bool useSIMD, bool is3x3)
|
||||
{
|
||||
const int VEC_NLANES = 8;
|
||||
__m256 vminval = _mm256_set1_ps(minval);
|
||||
__m256 vmaxval = _mm256_set1_ps(maxval);
|
||||
|
||||
__m256 w0 = _mm256_setzero_ps(),
|
||||
w1 = w0, w2 = w0, w3 = w0, w4 = w0, w5 = w0, w6 = w0, w7 = w0, w8 = w0, vbias = w0;
|
||||
|
||||
if (useSIMD)
|
||||
{
|
||||
vbias = _mm256_set1_ps(biasval);
|
||||
if (is3x3)
|
||||
{
|
||||
w0 = _mm256_set1_ps(weights[0]);
|
||||
w1 = _mm256_set1_ps(weights[1]);
|
||||
w2 = _mm256_set1_ps(weights[2]);
|
||||
w3 = _mm256_set1_ps(weights[3]);
|
||||
w4 = _mm256_set1_ps(weights[4]);
|
||||
w5 = _mm256_set1_ps(weights[5]);
|
||||
w6 = _mm256_set1_ps(weights[6]);
|
||||
w7 = _mm256_set1_ps(weights[7]);
|
||||
w8 = _mm256_set1_ps(weights[8]);
|
||||
}
|
||||
}
|
||||
|
||||
int dy0 = 1;
|
||||
for (int y0 = 0; y0 < H0; y0 += dy0, outptr += W0 * dy0)
|
||||
{
|
||||
dy0 = inner_ytop <= y0 && y0 + 3 < inner_ybottom && is3x3 && stride_y == 1 && dilation_y == 1
|
||||
? 3 : 1;
|
||||
|
||||
int x0 = 0, x1 = y0 >= inner_ytop && y0 < inner_ybottom ? inner_xleft : W0;
|
||||
int yi_ = y0 * stride_y - pad_top;
|
||||
|
||||
for (;;)
|
||||
{
|
||||
float s_0, s_1, s_2;
|
||||
if (dy0 == 3)
|
||||
{
|
||||
for (; x0 < x1; x0++)
|
||||
{
|
||||
int xi_ = x0 * stride_x - pad_left;
|
||||
s_0 = s_1 = s_2 = biasval;
|
||||
for (int k = 0; k < ksize; k++)
|
||||
{
|
||||
int dy = yxtab[k * 2];
|
||||
int yi = yi_ + dy;
|
||||
int xi = xi_ + yxtab[k * 2 + 1];
|
||||
float w = weights[k];
|
||||
|
||||
if ((unsigned) xi < (unsigned) Wi)
|
||||
{
|
||||
s_0 += inptr[yi * Wi + xi] * w;
|
||||
s_1 += inptr[(yi + 1) * Wi + xi] * w;
|
||||
s_2 += inptr[(yi + 2) * Wi + xi] * w;
|
||||
}
|
||||
}
|
||||
if (ifMinMaxAct)
|
||||
{
|
||||
s_0 = std::min(std::max(s_0, minval), maxval);
|
||||
s_1 = std::min(std::max(s_1, minval), maxval);
|
||||
s_2 = std::min(std::max(s_2, minval), maxval);
|
||||
}
|
||||
|
||||
outptr[x0] = s_0;
|
||||
outptr[x0 + W0] = s_1;
|
||||
outptr[x0 + W0 * 2] = s_2;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (; x0 < x1; x0++)
|
||||
{
|
||||
int xi_ = x0 * stride_x - pad_left;
|
||||
s_0 = biasval;
|
||||
for (int k = 0; k < ksize; k++) {
|
||||
int dy = yxtab[k * 2];
|
||||
int yi = yi_ + dy;
|
||||
int xi = xi_ + yxtab[k * 2 + 1];
|
||||
float w = weights[k];
|
||||
if (((unsigned) yi < (unsigned) Hi) & ((unsigned) xi < (unsigned) Wi))
|
||||
s_0 += inptr[yi * Wi + xi] * w;
|
||||
}
|
||||
if (ifMinMaxAct)
|
||||
s_0 = std::min(std::max(s_0, minval), maxval);
|
||||
outptr[x0] = s_0;
|
||||
}
|
||||
}
|
||||
if (x0 == W0)
|
||||
break;
|
||||
x1 = inner_xright;
|
||||
|
||||
if (useSIMD)
|
||||
{
|
||||
if (is3x3)
|
||||
{
|
||||
if (dy0 == 3)
|
||||
{
|
||||
for (; x0 <= x1 - VEC_NLANES; x0 += VEC_NLANES)
|
||||
{
|
||||
int xi_ = x0 * stride_x - pad_left;
|
||||
const float *inptr_xi = inptr + Wi * yi_ + xi_;
|
||||
|
||||
__m256 s0, s1, s2;
|
||||
__m256 x00 = _mm256_loadu_ps(inptr_xi);
|
||||
__m256 x01 = _mm256_loadu_ps(inptr_xi + 1);
|
||||
__m256 x02 = _mm256_loadu_ps(inptr_xi + 2);
|
||||
|
||||
__m256 x10 = _mm256_loadu_ps(inptr_xi + Wi);
|
||||
__m256 x11 = _mm256_loadu_ps(inptr_xi + Wi + 1);
|
||||
__m256 x12 = _mm256_loadu_ps(inptr_xi + Wi + 2);
|
||||
|
||||
__m256 x20 = _mm256_loadu_ps(inptr_xi + Wi * 2);
|
||||
__m256 x21 = _mm256_loadu_ps(inptr_xi + Wi * 2 + 1);
|
||||
__m256 x22 = _mm256_loadu_ps(inptr_xi + Wi * 2 + 2);
|
||||
|
||||
__m256 x30 = _mm256_loadu_ps(inptr_xi + Wi * 3);
|
||||
__m256 x31 = _mm256_loadu_ps(inptr_xi + Wi * 3 + 1);
|
||||
__m256 x32 = _mm256_loadu_ps(inptr_xi + Wi * 3 + 2);
|
||||
|
||||
__m256 x40 = _mm256_loadu_ps(inptr_xi + Wi * 4);
|
||||
__m256 x41 = _mm256_loadu_ps(inptr_xi + Wi * 4 + 1);
|
||||
__m256 x42 = _mm256_loadu_ps(inptr_xi + Wi * 4 + 2);
|
||||
|
||||
s0 = _mm256_fmadd_ps(x00, w0, vbias);
|
||||
s1 = _mm256_fmadd_ps(x10, w0, vbias);
|
||||
s2 = _mm256_fmadd_ps(x20, w0, vbias);
|
||||
|
||||
s0 = _mm256_fmadd_ps(x01, w1, s0);
|
||||
s1 = _mm256_fmadd_ps(x11, w1, s1);
|
||||
s2 = _mm256_fmadd_ps(x21, w1, s2);
|
||||
|
||||
s0 = _mm256_fmadd_ps(x02, w2, s0);
|
||||
s1 = _mm256_fmadd_ps(x12, w2, s1);
|
||||
s2 = _mm256_fmadd_ps(x22, w2, s2);
|
||||
|
||||
s0 = _mm256_fmadd_ps(x10, w3, s0);
|
||||
s1 = _mm256_fmadd_ps(x20, w3, s1);
|
||||
s2 = _mm256_fmadd_ps(x30, w3, s2);
|
||||
|
||||
s0 = _mm256_fmadd_ps(x11, w4, s0);
|
||||
s1 = _mm256_fmadd_ps(x21, w4, s1);
|
||||
s2 = _mm256_fmadd_ps(x31, w4, s2);
|
||||
|
||||
s0 = _mm256_fmadd_ps(x12, w5, s0);
|
||||
s1 = _mm256_fmadd_ps(x22, w5, s1);
|
||||
s2 = _mm256_fmadd_ps(x32, w5, s2);
|
||||
|
||||
s0 = _mm256_fmadd_ps(x20, w6, s0);
|
||||
s1 = _mm256_fmadd_ps(x30, w6, s1);
|
||||
s2 = _mm256_fmadd_ps(x40, w6, s2);
|
||||
|
||||
s0 = _mm256_fmadd_ps(x21, w7, s0);
|
||||
s1 = _mm256_fmadd_ps(x31, w7, s1);
|
||||
s2 = _mm256_fmadd_ps(x41, w7, s2);
|
||||
|
||||
s0 = _mm256_fmadd_ps(x22, w8, s0);
|
||||
s1 = _mm256_fmadd_ps(x32, w8, s1);
|
||||
s2 = _mm256_fmadd_ps(x42, w8, s2);
|
||||
|
||||
if (ifMinMaxAct)
|
||||
{
|
||||
s0 = _mm256_min_ps(_mm256_max_ps(s0, vminval), vmaxval);
|
||||
s1 = _mm256_min_ps(_mm256_max_ps(s1, vminval), vmaxval);
|
||||
s2 = _mm256_min_ps(_mm256_max_ps(s2, vminval), vmaxval);
|
||||
}
|
||||
|
||||
_mm256_storeu_ps(outptr + x0, s0);
|
||||
_mm256_storeu_ps(outptr + W0 + x0, s1);
|
||||
_mm256_storeu_ps(outptr + W0 * 2 + x0, s2);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (; x0 <= x1 - VEC_NLANES; x0 += VEC_NLANES)
|
||||
{
|
||||
int xi_ = x0 * stride_x - pad_left;
|
||||
const float *inptr_xi = inptr + Wi * yi_ + xi_;
|
||||
__m256 s0 = _mm256_fmadd_ps(_mm256_loadu_ps(inptr_xi + ofstab[0]), w0, vbias);
|
||||
__m256 s1 = _mm256_mul_ps(_mm256_loadu_ps(inptr_xi + ofstab[1]), w1);
|
||||
__m256 s2 = _mm256_mul_ps(_mm256_loadu_ps(inptr_xi + ofstab[2]), w2);
|
||||
|
||||
s0 = _mm256_fmadd_ps(_mm256_loadu_ps(inptr_xi + ofstab[3]), w3, s0);
|
||||
s1 = _mm256_fmadd_ps(_mm256_loadu_ps(inptr_xi + ofstab[4]), w4, s1);
|
||||
s2 = _mm256_fmadd_ps(_mm256_loadu_ps(inptr_xi + ofstab[5]), w5, s2);
|
||||
|
||||
s0 = _mm256_fmadd_ps(_mm256_loadu_ps(inptr_xi + ofstab[6]), w6, s0);
|
||||
s1 = _mm256_fmadd_ps(_mm256_loadu_ps(inptr_xi + ofstab[7]), w7, s1);
|
||||
s2 = _mm256_fmadd_ps(_mm256_loadu_ps(inptr_xi + ofstab[8]), w8, s2);
|
||||
|
||||
s0 = _mm256_add_ps(_mm256_add_ps(s0, s1), s2);
|
||||
|
||||
if (ifMinMaxAct)
|
||||
s0 = _mm256_min_ps(_mm256_max_ps(s0, vminval), vmaxval);
|
||||
_mm256_storeu_ps(outptr + x0, s0);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (; x0 <= x1 - VEC_NLANES; x0 += VEC_NLANES)
|
||||
{
|
||||
int xi_ = x0 * stride_x - pad_left, k = 0;
|
||||
const float *inptr_xi = inptr + Wi * yi_ + xi_;
|
||||
__m256 s0 = vbias;
|
||||
for (; k <= ksize - 4; k += 4)
|
||||
{
|
||||
__m256 v0 = _mm256_loadu_ps(inptr_xi + ofstab[k]);
|
||||
__m256 v1 = _mm256_loadu_ps(inptr_xi + ofstab[k + 1]);
|
||||
__m256 v2 = _mm256_loadu_ps(inptr_xi + ofstab[k + 2]);
|
||||
__m256 v3 = _mm256_loadu_ps(inptr_xi + ofstab[k + 3]);
|
||||
|
||||
__m256 ww0 = _mm256_set1_ps(weights[k]);
|
||||
__m256 ww1 = _mm256_set1_ps(weights[k+1]);
|
||||
__m256 ww2 = _mm256_set1_ps(weights[k+2]);
|
||||
__m256 ww3 = _mm256_set1_ps(weights[k+3]);
|
||||
|
||||
s0 = _mm256_fmadd_ps(v0, ww0, s0);
|
||||
s0 = _mm256_fmadd_ps(v1, ww1, s0);
|
||||
s0 = _mm256_fmadd_ps(v2, ww2, s0);
|
||||
s0 = _mm256_fmadd_ps(v3, ww3, s0);
|
||||
}
|
||||
for (; k < ksize; k++)
|
||||
s0 = _mm256_fmadd_ps(_mm256_loadu_ps(inptr_xi + ofstab[k]),
|
||||
_mm256_set1_ps(weights[k]), s0);
|
||||
|
||||
if (ifMinMaxAct)
|
||||
s0 = _mm256_min_ps(_mm256_max_ps(s0, vminval), vmaxval);
|
||||
_mm256_storeu_ps(outptr + x0, s0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (dy0 == 3)
|
||||
{
|
||||
for (; x0 < x1; x0++)
|
||||
{
|
||||
int xi_ = x0 * stride_x - pad_left;
|
||||
const float *inptr_xi = inptr + W0 * yi_ + xi_;
|
||||
s_0 = s_1 = s_2 = biasval;
|
||||
for (int k = 0; k < ksize; k++) {
|
||||
int inp_ofs = ofstab[k];
|
||||
float w = weights[k];
|
||||
s_0 += inptr_xi[inp_ofs] * w;
|
||||
s_1 += inptr_xi[inp_ofs + Wi] * w;
|
||||
s_2 += inptr_xi[inp_ofs + Wi * 2] * w;
|
||||
}
|
||||
if (ifMinMaxAct)
|
||||
{
|
||||
s_0 = std::min(std::max(s_0, minval), maxval);
|
||||
s_1 = std::min(std::max(s_1, minval), maxval);
|
||||
s_2 = std::min(std::max(s_2, minval), maxval);
|
||||
}
|
||||
|
||||
outptr[x0] = s_0;
|
||||
outptr[x0 + W0] = s_1;
|
||||
outptr[x0 + W0 * 2] = s_2;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (; x0 < x1; x0++)
|
||||
{
|
||||
int xi_ = x0 * stride_x - pad_left;
|
||||
const float *inptr_xi = inptr + Wi * yi_ + xi_;
|
||||
s_0 = biasval;
|
||||
for (int k = 0; k < ksize; k++)
|
||||
{
|
||||
s_0 += inptr_xi[ofstab[k]] * weights[k];
|
||||
}
|
||||
if (ifMinMaxAct)
|
||||
s_0 = std::min(std::max(s_0, minval), maxval);
|
||||
outptr[x0] = s_0;
|
||||
}
|
||||
}
|
||||
x1 = W0;
|
||||
}
|
||||
}
|
||||
_mm256_zeroupper();
|
||||
}
|
||||
|
||||
void _fx_winograd_accum_f32(const float* inwptr, const float* wptr,
|
||||
float* outbuf, int Cg, int iblock)
|
||||
{
|
||||
@ -737,4 +495,5 @@ void _fx_winograd_AtXA_8x8_f32(const float* inptr, int inpstep,
|
||||
|
||||
#endif
|
||||
} // namespace opt_AVX2
|
||||
} // namespace dnn
|
||||
} // namespace cv
|
File diff suppressed because it is too large
Load Diff
@ -42,19 +42,20 @@ enum {
|
||||
|
||||
_FX_WINO_NATOMS_F32 = _FX_WINO_AREA / _FX_WINO_ATOM_F32, // for AVX2, it is 8, otherwise, it's 16.
|
||||
};
|
||||
enum { _FX_CONV_TYPE_GENERIC=0, _FX_CONV_TYPE_DEPTHWISE=1, _FX_CONV_TYPE_WINOGRAD3X3=2 };
|
||||
enum { _FX_CONV_TYPE_GENERIC=0, _FX_CONV_TYPE_DEPTHWISE=1, _FX_CONV_TYPE_WINOGRAD3X3=2, _FX_CONV_TYPE_DEPTHWISE_REMAIN=3 };
|
||||
enum { CONV_1D = 0, CONV_2D = 1, CONV_3D = 2 };
|
||||
#endif
|
||||
|
||||
namespace cv {
|
||||
namespace dnn {
|
||||
|
||||
struct FastConv2d
|
||||
struct FastConv
|
||||
{
|
||||
int ngroups;
|
||||
int K, C, Hk, Wk;
|
||||
int stride_y, stride_x;
|
||||
int dilation_y, dilation_x;
|
||||
int pad_top, pad_bottom, pad_left, pad_right;
|
||||
int K, C, Hk, Wk, Dk;
|
||||
int stride_h, stride_w, stride_d;
|
||||
int dilation_h, dilation_w, dilation_d;
|
||||
int pad_top, pad_bottom, pad_left, pad_right, pad_front, pad_behind;
|
||||
|
||||
std::vector<float> weightsBuf; // For generic Conv 2D
|
||||
float* weightsBufPtr;
|
||||
@ -62,57 +63,55 @@ struct FastConv2d
|
||||
float* weightsWinoBufPtr;
|
||||
std::vector<float> biasBuf;
|
||||
int conv_type;
|
||||
int conv_dim; // Flag for conv1d, conv2d, or conv3d.
|
||||
#if CV_SIMD128
|
||||
bool useSIMD128 = true;
|
||||
#else
|
||||
bool useSIMD128 = false;
|
||||
#endif
|
||||
|
||||
#if CV_TRY_AVX2
|
||||
bool useAVX2 = checkHardwareSupport(CPU_AVX2);
|
||||
#else
|
||||
bool useAVX2 = false;
|
||||
#endif
|
||||
|
||||
#if CV_NEON
|
||||
bool useNEON = checkHardwareSupport(CPU_NEON);
|
||||
#else
|
||||
bool useNEON = false;
|
||||
#endif
|
||||
|
||||
bool useAVX = checkHardwareSupport(CPU_AVX);
|
||||
bool useAVX2 = checkHardwareSupport(CPU_AVX2);
|
||||
bool useRVV = checkHardwareSupport(CPU_RVV);
|
||||
};
|
||||
|
||||
// return a FastConv2d instance.
|
||||
Ptr<FastConv2d> initFastConv2d(
|
||||
// return a FastConv instance.
|
||||
Ptr<FastConv> initFastConv(
|
||||
InputArray weightsMat,
|
||||
float* srcBias,
|
||||
int ngroups,
|
||||
int K, int C, int Hk, int Wk,
|
||||
int stride_x, int stride_y,
|
||||
int dilation_x, int dilation_y,
|
||||
int K, int C,
|
||||
const std::vector<size_t>& kernel_size,
|
||||
const std::vector<size_t>& strides,
|
||||
const std::vector<size_t>& dilations,
|
||||
const std::vector<size_t>& pads_begin,
|
||||
const std::vector<size_t>& pads_end,
|
||||
InputArray weightsMat,
|
||||
float* srcBias, bool useWinograd);
|
||||
int conv_dim,
|
||||
bool useWinograd);
|
||||
|
||||
// It contains different computing branches, like winograd, 1x1 conv.
|
||||
void runFastConv2d(InputArray _input, OutputArray _output, const Ptr<FastConv2d>& conv, int ntasks,
|
||||
const Ptr<ActivationLayer>& actLayer, bool fusedAdd);
|
||||
void runFastConv(InputArray _input, OutputArray _output, const Ptr<FastConv>& conv, int ntasks,
|
||||
const Ptr<ActivationLayer>& actLayer, const std::vector<float>& reluslope, bool fusedAdd);
|
||||
|
||||
void runDepthwise(InputArray _input, OutputArray _output, const Ptr<FastConv2d>& conv, float minval, float maxval,
|
||||
ActivationLayer* activ, bool ifMinMaxAct);
|
||||
void runDepthwise(InputArray _input, OutputArray _output, const Ptr<FastConv>& conv, ActivationLayer* activ,
|
||||
const std::vector<float>& reluslope);
|
||||
|
||||
int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _output, const Ptr<FastConv2d>& conv, int ntasks,
|
||||
int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _output, const Ptr<FastConv>& conv, int ntasks,
|
||||
float minval, float maxval, ActivationLayer* activ, bool ifMinMaxAct);
|
||||
|
||||
} // namespace dnn
|
||||
|
||||
namespace opt_AVX2
|
||||
{
|
||||
#if CV_TRY_AVX2
|
||||
void convBlock_AVX2(int np, const float* a, const float* b, float* c, int ldc, bool init_c);
|
||||
|
||||
void depthWiseBlock_AVX2(const float *inptr, float *outptr, const float *weights, float biasval, int *ofstab, int *yxtab,
|
||||
float minval, float maxval, int Hi, int Wi, int H0, int W0, int ksize, int pad_top, int pad_left,
|
||||
int dilation_y, int stride_x, int stride_y, int inner_xleft, int inner_xright, int inner_ytop,
|
||||
int inner_ybottom, bool ifMinMaxAct, bool useSIMD, bool is3x3);
|
||||
void convBlockMR1(int np, const float* a, const float* b, float *c, const float bias, bool init_c, const float minval,
|
||||
const float maxval, bool ifMinMaxAct);
|
||||
|
||||
void _fx_winograd_accum_f32(const float* inwptr, const float* wptr, float* outbuf, int Cg, int iblock);
|
||||
void _fx_winograd_BtXB_8x8_f32(const float* inptr, int inpstep, float* outptr, int Cg);
|
||||
@ -122,6 +121,7 @@ void _fx_winograd_AtXA_8x8_f32(const float* inptr, int inpstep, float* bpptr, in
|
||||
#endif
|
||||
} // namespace opt_AVX2
|
||||
|
||||
} // namespace dnn
|
||||
} // namespace cv
|
||||
|
||||
#endif //OPENCV_FAST_CONVOLUTION_HPP
|
||||
|
@ -11,9 +11,132 @@
|
||||
namespace cv {
|
||||
namespace dnn {
|
||||
|
||||
void convBlock(int np, const float* a, const float* b, float* c, int ldc, bool init_c)
|
||||
static void convBlockMR1NoSIMD(int np, const float* a, const float* b, float *c, const float bias, bool init_c,
|
||||
const float minval, const float maxval, bool ifMinMaxAct, const int outLen)
|
||||
{
|
||||
std::vector<float> cbuffer(outLen, 0);
|
||||
float* cbuf = cbuffer.data();
|
||||
for( int p = 0; p < np; p++ )
|
||||
{
|
||||
float ai = a[p];
|
||||
for( int j = 0; j < outLen; j++ )
|
||||
cbuf[j] += b[CONV_NR*p + j] * ai;
|
||||
}
|
||||
|
||||
if (init_c)
|
||||
{
|
||||
for(int j = 0; j < outLen; j++)
|
||||
{
|
||||
c[j] += cbuf[j] + bias;
|
||||
if (ifMinMaxAct)
|
||||
c[j] = std::min(std::max(c[j], minval), maxval);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int j = 0; j < outLen; j++)
|
||||
{
|
||||
c[j] = cbuf[j] + bias;
|
||||
if (ifMinMaxAct)
|
||||
c[j] = std::min(std::max(c[j], minval), maxval);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void convBlockMR1(int np, const float* a, const float* b, float *c, const float bias, bool init_c,
|
||||
const float minval, const float maxval, bool ifMinMaxAct, const int outLen)
|
||||
{
|
||||
#if CV_SIMD128
|
||||
// The outLen represents the valid output value in CONV_NR length.
|
||||
// When outLen is very small, we use the no-SIMD branch.
|
||||
const int CONV_NRby3 = CONV_NR/3;
|
||||
if (outLen > CONV_NRby3)
|
||||
{
|
||||
v_float32x4 c0 = v_setall_f32(bias), c1 = c0, c2 = c0; // CONV_NR == 12
|
||||
#if CONV_NR == 28 || CONV_NR == 24
|
||||
v_float32x4 c3 = c0, c4 = c0, c5 = c0;
|
||||
#endif
|
||||
#if CONV_NR == 28
|
||||
v_float32x4 c6 = c0;
|
||||
#endif
|
||||
for (int p = 0; p < np; p++, a++, b += CONV_NR)
|
||||
{
|
||||
v_float32x4 a0 = v_setall_f32(a[0]);
|
||||
v_float32x4 b0 = v_load(b), b1 = v_load(b + 4), b2 = v_load(b + 8);
|
||||
#if CONV_NR == 28 || CONV_NR == 24
|
||||
v_float32x4 b3 = v_load(b + 12), b4 = v_load(b + 16), b5 = v_load(b + 20);
|
||||
#endif
|
||||
#if CONV_NR == 28
|
||||
v_float32x4 b6 = v_load(b + 24);
|
||||
#endif
|
||||
|
||||
c0 = v_fma(b0, a0, c0);
|
||||
c1 = v_fma(b1, a0, c1);
|
||||
c2 = v_fma(b2, a0, c2);
|
||||
#if CONV_NR == 28 || CONV_NR == 24
|
||||
c3 = v_fma(b3, a0, c3);
|
||||
c4 = v_fma(b4, a0, c4);
|
||||
c5 = v_fma(b5, a0, c5);
|
||||
#endif
|
||||
#if CONV_NR == 28
|
||||
c6 = v_fma(b6, a0, c6);
|
||||
#endif
|
||||
}
|
||||
|
||||
if (init_c)
|
||||
{
|
||||
c0 += v_load(c);
|
||||
c1 += v_load(c + 4);
|
||||
c2 += v_load(c + 8);
|
||||
#if CONV_NR == 28 || CONV_NR == 24
|
||||
c3 += v_load(c + 12);
|
||||
c4 += v_load(c + 16);
|
||||
c5 += v_load(c + 20);
|
||||
#endif
|
||||
#if CONV_NR == 28
|
||||
c6 += v_load(c + 24);
|
||||
#endif
|
||||
}
|
||||
|
||||
if (ifMinMaxAct)
|
||||
{
|
||||
v_float32x4 vmax = v_setall_f32(maxval), vmin = v_setall_f32(minval);
|
||||
c0 = v_min(v_max(c0, vmin), vmax);
|
||||
c1 = v_min(v_max(c1, vmin), vmax);
|
||||
c2 = v_min(v_max(c2, vmin), vmax);
|
||||
#if CONV_NR == 28 || CONV_NR == 24
|
||||
c3 = v_min(v_max(c3, vmin), vmax);
|
||||
c4 = v_min(v_max(c4, vmin), vmax);
|
||||
c5 = v_min(v_max(c5, vmin), vmax);
|
||||
#endif
|
||||
#if CONV_NR == 28
|
||||
c6 = v_min(v_max(c6, vmin), vmax);
|
||||
#endif
|
||||
}
|
||||
|
||||
v_store(c, c0);
|
||||
v_store(c + 4, c1);
|
||||
v_store(c + 8, c2);
|
||||
#if CONV_NR == 28 || CONV_NR == 24
|
||||
v_store(c + 12, c3);
|
||||
v_store(c + 16, c4);
|
||||
v_store(c + 20, c5);
|
||||
#endif
|
||||
#if CONV_NR == 28
|
||||
v_store(c + 24, c6);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
convBlockMR1NoSIMD(np, a, b, c, bias, init_c, minval, maxval, ifMinMaxAct, outLen);
|
||||
#else
|
||||
convBlockMR1NoSIMD(np, a, b, c, bias, init_c, minval, maxval, ifMinMaxAct, outLen);
|
||||
#endif
|
||||
}
|
||||
|
||||
#if CV_SIMD128
|
||||
#if CONV_MR == 4 && CONV_NR == 24
|
||||
static void convBlock4x24(int np, const float* a, const float* b, float* c, int ldc, bool init_c)
|
||||
{
|
||||
#if CV_SIMD128 && CONV_MR == 4 && CONV_NR == 24
|
||||
v_float32x4 c0 = v_setzero_f32(), c1 = c0, c2 = c0, c3 = c0, c4 = c0, c5 = c0;
|
||||
v_float32x4 c6 = v_setzero_f32(), c7 = c6, c8 = c6, c9 = c6, c10 = c6, c11 = c6;
|
||||
v_float32x4 c12 = v_setzero_f32(), c13 = c12, c14 = c12, c15 = c12, c16 = c12, c17 = c12;
|
||||
@ -115,29 +238,156 @@ void convBlock(int np, const float* a, const float* b, float* c, int ldc, bool i
|
||||
v_store(c + ldc * 3 + 12, c21);
|
||||
v_store(c + ldc * 3 + 16, c22);
|
||||
v_store(c + ldc * 3 + 20, c23);
|
||||
#else
|
||||
float cbuf[CONV_MR * CONV_NR];
|
||||
memset(cbuf, 0, sizeof(cbuf));
|
||||
}
|
||||
#endif
|
||||
|
||||
static void convBlock4x8(int np, const float* a, const float* b, float* c, int ldc, bool init_c)
|
||||
{
|
||||
CV_Assert(CONV_NR >= 4);
|
||||
v_float32x4 c0 = v_setzero_f32(), c1 = c0, c2 = c0, c3 = c0;
|
||||
v_float32x4 c4 = c0, c5 = c0, c6 = c0, c7 = c0;
|
||||
|
||||
for (int p = 0; p < np; p++, a += CONV_MR, b += CONV_NR)
|
||||
{
|
||||
v_float32x4 a0 = v_setall_f32(a[0]);
|
||||
v_float32x4 a1 = v_setall_f32(a[1]);
|
||||
v_float32x4 a2 = v_setall_f32(a[2]);
|
||||
v_float32x4 a3 = v_setall_f32(a[3]);
|
||||
|
||||
v_float32x4 b0 = v_load(b), b1 = v_load(b + 4);
|
||||
|
||||
c0 = v_fma(b0, a0, c0);
|
||||
c1 = v_fma(b1, a0, c1);
|
||||
|
||||
c2 = v_fma(b0, a1, c2);
|
||||
c3 = v_fma(b1, a1, c3);
|
||||
|
||||
c4 = v_fma(b0, a2, c4);
|
||||
c5 = v_fma(b1, a2, c5);
|
||||
|
||||
c6 = v_fma(b0, a3, c6);
|
||||
c7 = v_fma(b1, a3, c7);
|
||||
}
|
||||
|
||||
if (!init_c)
|
||||
{
|
||||
c0 += v_load(c);
|
||||
c1 += v_load(c + 4);
|
||||
|
||||
c2 += v_load(c + ldc);
|
||||
c3 += v_load(c + ldc + 4);
|
||||
|
||||
c4 += v_load(c + ldc*2);
|
||||
c5 += v_load(c + ldc*2 + 4);
|
||||
|
||||
c6 += v_load(c + ldc*3);
|
||||
c7 += v_load(c + ldc*3 + 4);
|
||||
}
|
||||
|
||||
v_store(c, c0);
|
||||
v_store(c + 4, c1);
|
||||
v_store(c + ldc, c2);
|
||||
v_store(c + ldc + 4, c3);
|
||||
v_store(c + ldc * 2, c4);
|
||||
v_store(c + ldc * 2 + 4, c5);
|
||||
v_store(c + ldc * 3, c6);
|
||||
v_store(c + ldc * 3 + 4, c7);
|
||||
}
|
||||
|
||||
static void convBlock4x4(int np, const float* a, const float* b, float* c, int ldc, bool init_c)
|
||||
{
|
||||
CV_Assert(CONV_NR >= 4);
|
||||
v_float32x4 c0 = v_setzero_f32(), c1 = c0, c2 = c0, c3 = c0;
|
||||
|
||||
for (int p = 0; p < np; p++, a += CONV_MR, b += CONV_NR)
|
||||
{
|
||||
v_float32x4 a0 = v_setall_f32(a[0]);
|
||||
v_float32x4 a1 = v_setall_f32(a[1]);
|
||||
v_float32x4 a2 = v_setall_f32(a[2]);
|
||||
v_float32x4 a3 = v_setall_f32(a[3]);
|
||||
|
||||
v_float32x4 b0 = v_load(b);
|
||||
|
||||
c0 = v_fma(b0, a0, c0);
|
||||
c1 = v_fma(b0, a1, c1);
|
||||
c2 = v_fma(b0, a2, c2);
|
||||
c3 = v_fma(b0, a3, c3);
|
||||
}
|
||||
|
||||
if (!init_c)
|
||||
{
|
||||
c0 += v_load(c);
|
||||
c1 += v_load(c + ldc);
|
||||
c2 += v_load(c + ldc*2);
|
||||
c3 += v_load(c + ldc*3);
|
||||
}
|
||||
|
||||
v_store(c, c0);
|
||||
v_store(c + ldc, c1);
|
||||
v_store(c + ldc * 2, c2);
|
||||
v_store(c + ldc * 3, c3);
|
||||
}
|
||||
#endif
|
||||
|
||||
static void convBlockNoSIMD(int np, const float* a, const float* b, float* c, int ldc, bool init_c, const int outLen)
|
||||
{
|
||||
std::vector<float> cbuffer(CONV_MR * outLen, 0);
|
||||
float* cbuf = cbuffer.data();
|
||||
for( int p = 0; p < np; p++ )
|
||||
{
|
||||
for( int i = 0; i < CONV_MR; i++ )
|
||||
{
|
||||
float ai = a[CONV_MR*p + i];
|
||||
for( int j = 0; j < CONV_NR; j++ )
|
||||
cbuf[i * CONV_NR+j] += b[CONV_NR*p + j] * ai;
|
||||
for( int j = 0; j < outLen; j++ )
|
||||
cbuf[i * outLen+j] += b[CONV_NR*p + j] * ai;
|
||||
}
|
||||
}
|
||||
if (!init_c) {
|
||||
for(int i = 0; i < CONV_MR; i++) {
|
||||
for(int j = 0; j < CONV_NR; j++)
|
||||
c[i*ldc + j] += cbuf[i*CONV_NR + j];
|
||||
}
|
||||
} else {
|
||||
for(int i = 0; i < CONV_MR; i++) {
|
||||
for(int j = 0; j < CONV_NR; j++)
|
||||
c[i*ldc + j] = cbuf[i*CONV_NR + j];
|
||||
|
||||
if (!init_c)
|
||||
{
|
||||
for(int i = 0; i < CONV_MR; i++)
|
||||
{
|
||||
for(int j = 0; j < outLen; j++)
|
||||
c[i*ldc + j] += cbuf[i*outLen + j];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i = 0; i < CONV_MR; i++)
|
||||
{
|
||||
for(int j = 0; j < outLen; j++)
|
||||
c[i*ldc + j] = cbuf[i*outLen + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void convBlock(int np, const float* a, const float* b, float* c, int ldc, bool init_c, const int outLen)
|
||||
{
|
||||
// The possible outLen range is [24, 8~1].
|
||||
#if CV_SIMD128
|
||||
#if CONV_MR == 4 && CONV_NR == 24
|
||||
const int CONV_NRby3 = CONV_NR/3;
|
||||
if (outLen > CONV_NRby3)
|
||||
{
|
||||
convBlock4x24(np, a, b, c, ldc, init_c);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (outLen <= 8 && outLen > 4)
|
||||
{
|
||||
convBlock4x8(np, a, b, c, ldc, init_c);
|
||||
return;
|
||||
}
|
||||
|
||||
if (outLen <= 4 && outLen > 1)
|
||||
{
|
||||
convBlock4x4(np, a, b, c, ldc, init_c);
|
||||
return;
|
||||
}
|
||||
convBlockNoSIMD(np, a, b, c, ldc, init_c, outLen);
|
||||
#else
|
||||
convBlockNoSIMD(np, a, b, c, ldc, init_c, outLen);
|
||||
#endif
|
||||
}
|
||||
} // namespace dnn
|
||||
|
@ -920,7 +920,7 @@ _fx_winograd_AtXA_8x8_f32(const float* inptr, int inpstep,
|
||||
#endif
|
||||
}
|
||||
|
||||
int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _output, const Ptr<FastConv2d>& conv,
|
||||
int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _output, const Ptr<FastConv>& conv,
|
||||
int ntasks, float minval, float maxval, ActivationLayer* activ, bool ifMinMaxAct)
|
||||
{
|
||||
Mat input = _input.getMat();
|
||||
@ -1144,7 +1144,7 @@ int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _outpu
|
||||
|
||||
#else
|
||||
|
||||
int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _output, const Ptr<FastConv2d>& conv,
|
||||
int runWinograd63(InputArray _input, InputArray _fusedAddMat, OutputArray _output, const Ptr<FastConv>& conv,
|
||||
int ntasks, float minval, float maxval, ActivationLayer* activ, bool ifMinMaxAct)
|
||||
{
|
||||
return 0;
|
||||
|
@ -46,10 +46,6 @@ namespace cv {
|
||||
namespace dnn {
|
||||
CV_CPU_OPTIMIZATION_NAMESPACE_BEGIN
|
||||
|
||||
void fastConv( const float* weights, size_t wstep, const float* bias,
|
||||
const float* rowbuf, float* output, const int* outShape,
|
||||
int blockSize, int vecsize, int vecsize_aligned,
|
||||
const float* relu, bool initOutput );
|
||||
void fastDepthwiseConv( const float* weights,
|
||||
int kernel_h, int kernel_w,
|
||||
int stride_h, int stride_w,
|
||||
@ -74,305 +70,6 @@ void fastGEMM( const float* aptr, size_t astep, const float* bptr,
|
||||
#define _mm256_fmadd_ps(a, b, c) _mm256_add_ps(c, _mm256_mul_ps(a, b))
|
||||
#endif
|
||||
|
||||
enum { FASCONV_BASE_VECSZ = 4 };
|
||||
|
||||
void fastConv( const float* weights, size_t wstep, const float* bias,
|
||||
const float* rowbuf, float* output, const int* outShape,
|
||||
int blockSize, int vecsize, int vecsize_aligned,
|
||||
const float* relu, bool initOutput )
|
||||
{
|
||||
CV_Assert(isAligned<32>(weights));
|
||||
|
||||
int outCn = outShape[1];
|
||||
size_t outPlaneSize = outShape[2]*outShape[3];
|
||||
float r0 = 1.f, r1 = 1.f, r2 = 1.f;
|
||||
__m128 vr0 = _mm_set1_ps(1.f), vr1 = vr0, vr2 = vr0, z = _mm_setzero_ps();
|
||||
int CV_DECL_ALIGNED(16) maskbuf[FASCONV_BASE_VECSZ] = {0};
|
||||
int rsz = blockSize % FASCONV_BASE_VECSZ;
|
||||
for( int i = 0; i < rsz; i++ )
|
||||
maskbuf[FASCONV_BASE_VECSZ - i - 1] = -1;
|
||||
__m128 mask = _mm_loadu_ps((const float*)maskbuf);
|
||||
|
||||
// now compute dot product of the weights
|
||||
// and im2row-transformed part of the tensor
|
||||
for( int i = 0; i < outCn; i += 3 )
|
||||
{
|
||||
const float* wptr0 = weights + i*wstep;
|
||||
const float* wptr1 = wptr0 + wstep;
|
||||
const float* wptr2 = wptr1 + wstep;
|
||||
float* outptr0 = output + i*outPlaneSize;
|
||||
float* outptr1 = outptr0 + outPlaneSize;
|
||||
float* outptr2 = outptr1 + outPlaneSize;
|
||||
float bias0 = bias[i], bias1 = bias[i+1], bias2 = bias[i+2];
|
||||
|
||||
if( i+2 >= outCn )
|
||||
{
|
||||
wptr2 = wptr1;
|
||||
outptr2 = outptr1;
|
||||
bias2 = bias1;
|
||||
if( i+1 >= outCn )
|
||||
{
|
||||
wptr2 = wptr1 = wptr0;
|
||||
outptr2 = outptr1 = outptr0;
|
||||
bias2 = bias1 = bias0;
|
||||
}
|
||||
}
|
||||
|
||||
if( relu )
|
||||
{
|
||||
r0 = relu[i]; r1 = relu[i+1]; r2 = relu[i+2];
|
||||
if( i+2 >= outCn )
|
||||
{
|
||||
r2 = r1;
|
||||
if( i+1 >= outCn )
|
||||
r2 = r1 = r0;
|
||||
}
|
||||
vr0 = _mm_set1_ps(r0);
|
||||
vr1 = _mm_set1_ps(r1);
|
||||
vr2 = _mm_set1_ps(r2);
|
||||
}
|
||||
|
||||
int j = 0;
|
||||
for( ; j < blockSize; j += FASCONV_BASE_VECSZ )
|
||||
{
|
||||
bool tail = false;
|
||||
if (j + FASCONV_BASE_VECSZ > blockSize)
|
||||
{
|
||||
if (j == 0)
|
||||
break;
|
||||
j = blockSize - FASCONV_BASE_VECSZ;
|
||||
tail = true;
|
||||
}
|
||||
int k = 0;
|
||||
const float* rptr = rowbuf + j*vecsize_aligned;
|
||||
|
||||
__m256 vs00 = _mm256_setzero_ps(), vs01 = _mm256_setzero_ps(),
|
||||
vs02 = _mm256_setzero_ps(), vs03 = _mm256_setzero_ps(),
|
||||
vs10 = _mm256_setzero_ps(), vs11 = _mm256_setzero_ps(),
|
||||
vs12 = _mm256_setzero_ps(), vs13 = _mm256_setzero_ps(),
|
||||
vs20 = _mm256_setzero_ps(), vs21 = _mm256_setzero_ps(),
|
||||
vs22 = _mm256_setzero_ps(), vs23 = _mm256_setzero_ps();
|
||||
|
||||
#if CV_AVX512_SKX // AVX512VL is necessary to avoid register spilling
|
||||
if (vecsize >= 32)
|
||||
{
|
||||
__m512 vs00_5 = _mm512_setzero_ps(), vs01_5 = _mm512_setzero_ps(),
|
||||
vs02_5 = _mm512_setzero_ps(), vs03_5 = _mm512_setzero_ps(),
|
||||
vs10_5 = _mm512_setzero_ps(), vs11_5 = _mm512_setzero_ps(),
|
||||
vs12_5 = _mm512_setzero_ps(), vs13_5 = _mm512_setzero_ps(),
|
||||
vs20_5 = _mm512_setzero_ps(), vs21_5 = _mm512_setzero_ps(),
|
||||
vs22_5 = _mm512_setzero_ps(), vs23_5 = _mm512_setzero_ps();
|
||||
|
||||
for (; k <= vecsize - 16; k += 16, rptr += 16)
|
||||
{
|
||||
__m512 w0 = _mm512_loadu_ps(wptr0 + k);
|
||||
__m512 w1 = _mm512_loadu_ps(wptr1 + k);
|
||||
__m512 w2 = _mm512_loadu_ps(wptr2 + k);
|
||||
__m512 r0 = _mm512_loadu_ps(rptr);
|
||||
|
||||
vs00_5 = _mm512_fmadd_ps(w0, r0, vs00_5);
|
||||
vs10_5 = _mm512_fmadd_ps(w1, r0, vs10_5);
|
||||
vs20_5 = _mm512_fmadd_ps(w2, r0, vs20_5);
|
||||
|
||||
r0 = _mm512_loadu_ps(rptr + vecsize_aligned);
|
||||
vs01_5 = _mm512_fmadd_ps(w0, r0, vs01_5);
|
||||
vs11_5 = _mm512_fmadd_ps(w1, r0, vs11_5);
|
||||
vs21_5 = _mm512_fmadd_ps(w2, r0, vs21_5);
|
||||
|
||||
r0 = _mm512_loadu_ps(rptr + vecsize_aligned*2);
|
||||
vs02_5 = _mm512_fmadd_ps(w0, r0, vs02_5);
|
||||
vs12_5 = _mm512_fmadd_ps(w1, r0, vs12_5);
|
||||
vs22_5 = _mm512_fmadd_ps(w2, r0, vs22_5);
|
||||
|
||||
r0 = _mm512_loadu_ps(rptr + vecsize_aligned*3);
|
||||
vs03_5 = _mm512_fmadd_ps(w0, r0, vs03_5);
|
||||
vs13_5 = _mm512_fmadd_ps(w1, r0, vs13_5);
|
||||
vs23_5 = _mm512_fmadd_ps(w2, r0, vs23_5);
|
||||
}
|
||||
/*
|
||||
* now fold the 512 bit accumulator vectors into 256 bit vectors so that the AVX2 code can finish
|
||||
* the tail of the vector
|
||||
*/
|
||||
vs00 = _mm256_add_ps( _mm512_extractf32x8_ps(vs00_5, 0), _mm512_extractf32x8_ps(vs00_5, 1));
|
||||
vs10 = _mm256_add_ps( _mm512_extractf32x8_ps(vs10_5, 0), _mm512_extractf32x8_ps(vs10_5, 1));
|
||||
vs20 = _mm256_add_ps( _mm512_extractf32x8_ps(vs20_5, 0), _mm512_extractf32x8_ps(vs20_5, 1));
|
||||
|
||||
vs01 = _mm256_add_ps( _mm512_extractf32x8_ps(vs01_5, 0), _mm512_extractf32x8_ps(vs01_5, 1));
|
||||
vs11 = _mm256_add_ps( _mm512_extractf32x8_ps(vs11_5, 0), _mm512_extractf32x8_ps(vs11_5, 1));
|
||||
vs21 = _mm256_add_ps( _mm512_extractf32x8_ps(vs21_5, 0), _mm512_extractf32x8_ps(vs21_5, 1));
|
||||
|
||||
vs02 = _mm256_add_ps( _mm512_extractf32x8_ps(vs02_5, 0), _mm512_extractf32x8_ps(vs02_5, 1));
|
||||
vs12 = _mm256_add_ps( _mm512_extractf32x8_ps(vs12_5, 0), _mm512_extractf32x8_ps(vs12_5, 1));
|
||||
vs22 = _mm256_add_ps( _mm512_extractf32x8_ps(vs22_5, 0), _mm512_extractf32x8_ps(vs22_5, 1));
|
||||
|
||||
vs03 = _mm256_add_ps( _mm512_extractf32x8_ps(vs03_5, 0), _mm512_extractf32x8_ps(vs03_5, 1));
|
||||
vs13 = _mm256_add_ps( _mm512_extractf32x8_ps(vs13_5, 0), _mm512_extractf32x8_ps(vs13_5, 1));
|
||||
vs23 = _mm256_add_ps( _mm512_extractf32x8_ps(vs23_5, 0), _mm512_extractf32x8_ps(vs23_5, 1));
|
||||
}
|
||||
#endif
|
||||
|
||||
for (; k < vecsize; k += 8, rptr += 8 )
|
||||
{
|
||||
__m256 w0 = _mm256_load_ps(wptr0 + k);
|
||||
__m256 w1 = _mm256_load_ps(wptr1 + k);
|
||||
__m256 w2 = _mm256_load_ps(wptr2 + k);
|
||||
__m256 r0 = _mm256_load_ps(rptr);
|
||||
|
||||
vs00 = _mm256_fmadd_ps(w0, r0, vs00);
|
||||
vs10 = _mm256_fmadd_ps(w1, r0, vs10);
|
||||
vs20 = _mm256_fmadd_ps(w2, r0, vs20);
|
||||
|
||||
r0 = _mm256_load_ps(rptr + vecsize_aligned);
|
||||
vs01 = _mm256_fmadd_ps(w0, r0, vs01);
|
||||
vs11 = _mm256_fmadd_ps(w1, r0, vs11);
|
||||
vs21 = _mm256_fmadd_ps(w2, r0, vs21);
|
||||
|
||||
r0 = _mm256_load_ps(rptr + vecsize_aligned*2);
|
||||
vs02 = _mm256_fmadd_ps(w0, r0, vs02);
|
||||
vs12 = _mm256_fmadd_ps(w1, r0, vs12);
|
||||
vs22 = _mm256_fmadd_ps(w2, r0, vs22);
|
||||
|
||||
r0 = _mm256_load_ps(rptr + vecsize_aligned*3);
|
||||
vs03 = _mm256_fmadd_ps(w0, r0, vs03);
|
||||
vs13 = _mm256_fmadd_ps(w1, r0, vs13);
|
||||
vs23 = _mm256_fmadd_ps(w2, r0, vs23);
|
||||
}
|
||||
|
||||
__m256 t0 = _mm256_hadd_ps(_mm256_hadd_ps(vs00, vs01), _mm256_hadd_ps(vs02, vs03));
|
||||
__m256 t1 = _mm256_hadd_ps(_mm256_hadd_ps(vs10, vs11), _mm256_hadd_ps(vs12, vs13));
|
||||
__m256 t2 = _mm256_hadd_ps(_mm256_hadd_ps(vs20, vs21), _mm256_hadd_ps(vs22, vs23));
|
||||
|
||||
t0 = _mm256_add_ps(t0, _mm256_permute2f128_ps(t0, t0, 1));
|
||||
t1 = _mm256_add_ps(t1, _mm256_permute2f128_ps(t1, t1, 1));
|
||||
t2 = _mm256_add_ps(t2, _mm256_permute2f128_ps(t2, t2, 1));
|
||||
|
||||
__m128 s0, s1, s2;
|
||||
|
||||
if( initOutput )
|
||||
{
|
||||
s0 = _mm_set1_ps(bias0);
|
||||
s1 = _mm_set1_ps(bias1);
|
||||
s2 = _mm_set1_ps(bias2);
|
||||
}
|
||||
else
|
||||
{
|
||||
s0 = _mm_loadu_ps(outptr0 + j);
|
||||
s1 = _mm_loadu_ps(outptr1 + j);
|
||||
s2 = _mm_loadu_ps(outptr2 + j);
|
||||
}
|
||||
|
||||
s0 = _mm_add_ps(s0, _mm256_castps256_ps128(t0));
|
||||
s1 = _mm_add_ps(s1, _mm256_castps256_ps128(t1));
|
||||
s2 = _mm_add_ps(s2, _mm256_castps256_ps128(t2));
|
||||
|
||||
if( relu )
|
||||
{
|
||||
__m128 m0 = _mm_cmp_ps(s0, z, _CMP_GT_OS);
|
||||
__m128 m1 = _mm_cmp_ps(s1, z, _CMP_GT_OS);
|
||||
__m128 m2 = _mm_cmp_ps(s2, z, _CMP_GT_OS);
|
||||
s0 = _mm_blendv_ps(_mm_mul_ps(s0, vr0), s0, m0);
|
||||
s1 = _mm_blendv_ps(_mm_mul_ps(s1, vr1), s1, m1);
|
||||
s2 = _mm_blendv_ps(_mm_mul_ps(s2, vr2), s2, m2);
|
||||
}
|
||||
|
||||
if( tail )
|
||||
{
|
||||
s0 = _mm_blendv_ps(_mm_loadu_ps(outptr0 + j), s0, mask);
|
||||
s1 = _mm_blendv_ps(_mm_loadu_ps(outptr1 + j), s1, mask);
|
||||
s2 = _mm_blendv_ps(_mm_loadu_ps(outptr2 + j), s2, mask);
|
||||
}
|
||||
|
||||
_mm_storeu_ps(outptr0 + j, s0);
|
||||
_mm_storeu_ps(outptr1 + j, s1);
|
||||
_mm_storeu_ps(outptr2 + j, s2);
|
||||
}
|
||||
|
||||
for( ; j <= blockSize - 2; j += 2 )
|
||||
{
|
||||
const float* rptr0 = rowbuf + j*vecsize_aligned;
|
||||
const float* rptr1 = rowbuf + (j+1)*vecsize_aligned;
|
||||
float s00, s01, s10, s11, s20, s21;
|
||||
|
||||
if( initOutput )
|
||||
{
|
||||
s00 = s01 = bias0;
|
||||
s10 = s11 = bias1;
|
||||
s20 = s21 = bias2;
|
||||
}
|
||||
else
|
||||
{
|
||||
s00 = outptr0[j]; s01 = outptr0[j+1];
|
||||
s10 = outptr1[j]; s11 = outptr1[j+1];
|
||||
s20 = outptr2[j]; s21 = outptr2[j+1];
|
||||
}
|
||||
|
||||
for( int k = 0; k < vecsize; k++ )
|
||||
{
|
||||
float w0 = wptr0[k], w1 = wptr1[k], w2 = wptr2[k];
|
||||
float r = rptr0[k];
|
||||
s00 += w0*r; s10 += w1*r; s20 += w2*r;
|
||||
r = rptr1[k];
|
||||
s01 += w0*r; s11 += w1*r; s21 += w2*r;
|
||||
}
|
||||
|
||||
if( relu )
|
||||
{
|
||||
s00 = s00 > 0.f ? s00 : s00*r0;
|
||||
s01 = s01 > 0.f ? s01 : s01*r0;
|
||||
s10 = s10 > 0.f ? s10 : s10*r1;
|
||||
s11 = s11 > 0.f ? s11 : s11*r1;
|
||||
s20 = s20 > 0.f ? s20 : s20*r2;
|
||||
s21 = s21 > 0.f ? s21 : s21*r2;
|
||||
}
|
||||
|
||||
outptr0[j] = s00;
|
||||
outptr0[j+1] = s01;
|
||||
outptr1[j] = s10;
|
||||
outptr1[j+1] = s11;
|
||||
outptr2[j] = s20;
|
||||
outptr2[j+1] = s21;
|
||||
}
|
||||
|
||||
for( ; j < blockSize; j++ )
|
||||
{
|
||||
const float* rptr0 = rowbuf + j*vecsize_aligned;
|
||||
float s00, s10, s20;
|
||||
|
||||
if( initOutput )
|
||||
{
|
||||
s00 = bias0;
|
||||
s10 = bias1;
|
||||
s20 = bias2;
|
||||
}
|
||||
else
|
||||
{
|
||||
s00 = outptr0[j];
|
||||
s10 = outptr1[j];
|
||||
s20 = outptr2[j];
|
||||
}
|
||||
|
||||
for( int k = 0; k < vecsize; k++ )
|
||||
{
|
||||
float w0 = wptr0[k], w1 = wptr1[k], w2 = wptr2[k];
|
||||
float r = rptr0[k];
|
||||
s00 += w0*r; s10 += w1*r; s20 += w2*r;
|
||||
}
|
||||
|
||||
if( relu )
|
||||
{
|
||||
s00 = s00 > 0.f ? s00 : s00*r0;
|
||||
s10 = s10 > 0.f ? s10 : s10*r1;
|
||||
s20 = s20 > 0.f ? s20 : s20*r2;
|
||||
}
|
||||
|
||||
outptr0[j] = s00;
|
||||
outptr1[j] = s10;
|
||||
outptr2[j] = s20;
|
||||
}
|
||||
}
|
||||
_mm256_zeroupper();
|
||||
}
|
||||
|
||||
static inline void _mm256_load_deinterleave(const float* ptr, __m256& a, __m256& b)
|
||||
{
|
||||
__m256 t0 = _mm256_loadu_ps(ptr);
|
||||
@ -957,198 +654,6 @@ void fastGEMM1T( const float* vec, const float* weights,
|
||||
}
|
||||
}
|
||||
|
||||
enum { FASCONV_BASE_VECSZ = 8 };
|
||||
void fastConv( const float* weights, size_t wstep, const float* bias,
|
||||
const float* rowbuf, float* output, const int* outShape,
|
||||
int blockSize, int vecsize, int vecsize_aligned,
|
||||
const float* relu, bool initOutput )
|
||||
{
|
||||
const int vlm1 = vsetvlmax_e32m1();
|
||||
int outCn = outShape[1];
|
||||
size_t outPlaneSize = outShape[2]*outShape[3];
|
||||
// now compute dot product of the weights
|
||||
// and im2row-transformed part of the tensor
|
||||
for( int i = 0; i < outCn; i += 3 )
|
||||
{
|
||||
int unroll_tail = FASCONV_BASE_VECSZ;
|
||||
const float* wptr0 = weights + i*wstep;
|
||||
const float* wptr1 = wptr0 + wstep;
|
||||
const float* wptr2 = wptr1 + wstep;
|
||||
float* outptr0 = output + i*outPlaneSize;
|
||||
float* outptr1 = outptr0 + outPlaneSize;
|
||||
float* outptr2 = outptr1 + outPlaneSize;
|
||||
float bias0 = bias[i], bias1 = bias[i+1], bias2 = bias[i+2];
|
||||
|
||||
if( i+2 >= outCn )
|
||||
{
|
||||
wptr2 = wptr1;
|
||||
outptr2 = outptr1;
|
||||
bias2 = bias1;
|
||||
if( i+1 >= outCn )
|
||||
{
|
||||
wptr2 = wptr1 = wptr0;
|
||||
outptr2 = outptr1 = outptr0;
|
||||
bias2 = bias1 = bias0;
|
||||
}
|
||||
}
|
||||
|
||||
int j = 0;
|
||||
for( ; j < blockSize; j += FASCONV_BASE_VECSZ )
|
||||
{
|
||||
const float* rptr = rowbuf + j*vecsize_aligned;
|
||||
const float *rptr1 = rptr + vecsize_aligned*1,
|
||||
*rptr2 = rptr + vecsize_aligned*2,
|
||||
*rptr3 = rptr + vecsize_aligned*3,
|
||||
*rptr4 = rptr + vecsize_aligned*4,
|
||||
*rptr5 = rptr + vecsize_aligned*5,
|
||||
*rptr6 = rptr + vecsize_aligned*6,
|
||||
*rptr7 = rptr + vecsize_aligned*7;
|
||||
if (j + FASCONV_BASE_VECSZ > blockSize)
|
||||
{
|
||||
unroll_tail = blockSize - j;
|
||||
rptr1 = rptr + vecsize_aligned*std::min(1, unroll_tail-1),
|
||||
rptr2 = rptr + vecsize_aligned*std::min(2, unroll_tail-1),
|
||||
rptr3 = rptr + vecsize_aligned*std::min(3, unroll_tail-1),
|
||||
rptr4 = rptr + vecsize_aligned*std::min(4, unroll_tail-1),
|
||||
rptr5 = rptr + vecsize_aligned*std::min(5, unroll_tail-1),
|
||||
rptr6 = rptr + vecsize_aligned*std::min(6, unroll_tail-1),
|
||||
rptr7 = rptr + vecsize_aligned*std::min(7, unroll_tail-1);
|
||||
}
|
||||
|
||||
int vl, avl = vecsize;
|
||||
vfloat32m1_t
|
||||
vs00 = vfmv_v_f_f32m1(0, vlm1), vs10 = vfmv_v_f_f32m1(0, vlm1), vs20 = vfmv_v_f_f32m1(0, vlm1),
|
||||
vs01 = vfmv_v_f_f32m1(0, vlm1), vs11 = vfmv_v_f_f32m1(0, vlm1), vs21 = vfmv_v_f_f32m1(0, vlm1),
|
||||
vs02 = vfmv_v_f_f32m1(0, vlm1), vs12 = vfmv_v_f_f32m1(0, vlm1), vs22 = vfmv_v_f_f32m1(0, vlm1),
|
||||
vs03 = vfmv_v_f_f32m1(0, vlm1), vs13 = vfmv_v_f_f32m1(0, vlm1), vs23 = vfmv_v_f_f32m1(0, vlm1),
|
||||
vs04 = vfmv_v_f_f32m1(0, vlm1), vs14 = vfmv_v_f_f32m1(0, vlm1), vs24 = vfmv_v_f_f32m1(0, vlm1),
|
||||
vs05 = vfmv_v_f_f32m1(0, vlm1), vs15 = vfmv_v_f_f32m1(0, vlm1), vs25 = vfmv_v_f_f32m1(0, vlm1),
|
||||
vs06 = vfmv_v_f_f32m1(0, vlm1), vs16 = vfmv_v_f_f32m1(0, vlm1), vs26 = vfmv_v_f_f32m1(0, vlm1),
|
||||
vs07 = vfmv_v_f_f32m1(0, vlm1), vs17 = vfmv_v_f_f32m1(0, vlm1), vs27 = vfmv_v_f_f32m1(0, vlm1);
|
||||
|
||||
for (int k = 0; k < vecsize; k += vl, avl -= vl)
|
||||
{
|
||||
vl = vsetvl_e32m1(avl);
|
||||
vfloat32m1_t w0 = vle32_v_f32m1(wptr0 + k, vl);
|
||||
vfloat32m1_t w1 = vle32_v_f32m1(wptr1 + k, vl);
|
||||
vfloat32m1_t w2 = vle32_v_f32m1(wptr2 + k, vl);
|
||||
vfloat32m1_t r0 = vle32_v_f32m1(rptr, vl);
|
||||
|
||||
vs00 = vfmacc_vv_f32m1(vs00, w0, r0, vl);
|
||||
vs10 = vfmacc_vv_f32m1(vs10, w1, r0, vl);
|
||||
vs20 = vfmacc_vv_f32m1(vs20, w2, r0, vl);
|
||||
|
||||
r0 = vle32_v_f32m1(rptr1, vl);
|
||||
vs01 = vfmacc_vv_f32m1(vs01, w0, r0, vl);
|
||||
vs11 = vfmacc_vv_f32m1(vs11, w1, r0, vl);
|
||||
vs21 = vfmacc_vv_f32m1(vs21, w2, r0, vl);
|
||||
|
||||
r0 = vle32_v_f32m1(rptr2, vl);
|
||||
vs02 = vfmacc_vv_f32m1(vs02, w0, r0, vl);
|
||||
vs12 = vfmacc_vv_f32m1(vs12, w1, r0, vl);
|
||||
vs22 = vfmacc_vv_f32m1(vs22, w2, r0, vl);
|
||||
|
||||
r0 = vle32_v_f32m1(rptr3, vl);
|
||||
vs03 = vfmacc_vv_f32m1(vs03, w0, r0, vl);
|
||||
vs13 = vfmacc_vv_f32m1(vs13, w1, r0, vl);
|
||||
vs23 = vfmacc_vv_f32m1(vs23, w2, r0, vl);
|
||||
|
||||
r0 = vle32_v_f32m1(rptr4, vl);
|
||||
vs04 = vfmacc_vv_f32m1(vs04, w0, r0, vl);
|
||||
vs14 = vfmacc_vv_f32m1(vs14, w1, r0, vl);
|
||||
vs24 = vfmacc_vv_f32m1(vs24, w2, r0, vl);
|
||||
|
||||
r0 = vle32_v_f32m1(rptr5, vl);
|
||||
vs05 = vfmacc_vv_f32m1(vs05, w0, r0, vl);
|
||||
vs15 = vfmacc_vv_f32m1(vs15, w1, r0, vl);
|
||||
vs25 = vfmacc_vv_f32m1(vs25, w2, r0, vl);
|
||||
|
||||
r0 = vle32_v_f32m1(rptr6, vl);
|
||||
vs06 = vfmacc_vv_f32m1(vs06, w0, r0, vl);
|
||||
vs16 = vfmacc_vv_f32m1(vs16, w1, r0, vl);
|
||||
vs26 = vfmacc_vv_f32m1(vs26, w2, r0, vl);
|
||||
|
||||
r0 = vle32_v_f32m1(rptr7, vl);
|
||||
vs07 = vfmacc_vv_f32m1(vs07, w0, r0, vl);
|
||||
vs17 = vfmacc_vv_f32m1(vs17, w1, r0, vl);
|
||||
vs27 = vfmacc_vv_f32m1(vs27, w2, r0, vl);
|
||||
|
||||
rptr += vl; rptr1 += vl; rptr2 += vl; rptr3 += vl;
|
||||
rptr4 += vl; rptr5 += vl; rptr6 += vl; rptr7 += vl;
|
||||
}
|
||||
|
||||
// compute sum of each vs
|
||||
vfloat32m1_t zero = vfmv_v_f_f32m1(0, vlm1);
|
||||
// unroll_tail(vl) is required here to be at least FASCONV_BASE_VECSZ, aka 8.
|
||||
float sum0[FASCONV_BASE_VECSZ], sum1[FASCONV_BASE_VECSZ], sum2[FASCONV_BASE_VECSZ];
|
||||
sum0[0] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs00, zero, vlm1));
|
||||
sum0[1] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs01, zero, vlm1));
|
||||
sum0[2] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs02, zero, vlm1));
|
||||
sum0[3] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs03, zero, vlm1));
|
||||
sum0[4] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs04, zero, vlm1));
|
||||
sum0[5] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs05, zero, vlm1));
|
||||
sum0[6] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs06, zero, vlm1));
|
||||
sum0[7] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs07, zero, vlm1));
|
||||
sum1[0] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs10, zero, vlm1));
|
||||
sum1[1] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs11, zero, vlm1));
|
||||
sum1[2] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs12, zero, vlm1));
|
||||
sum1[3] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs13, zero, vlm1));
|
||||
sum1[4] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs14, zero, vlm1));
|
||||
sum1[5] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs15, zero, vlm1));
|
||||
sum1[6] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs16, zero, vlm1));
|
||||
sum1[7] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs17, zero, vlm1));
|
||||
sum2[0] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs20, zero, vlm1));
|
||||
sum2[1] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs21, zero, vlm1));
|
||||
sum2[2] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs22, zero, vlm1));
|
||||
sum2[3] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs23, zero, vlm1));
|
||||
sum2[4] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs24, zero, vlm1));
|
||||
sum2[5] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs25, zero, vlm1));
|
||||
sum2[6] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs26, zero, vlm1));
|
||||
sum2[7] = vfmv_f_s_f32m1_f32(vfredosum_vs_f32m1_f32m1(zero, vs27, zero, vlm1));
|
||||
|
||||
// if VLEN = 128, so LMUL = 2 for unroll_tail(vl) = 8.
|
||||
// otherwise, VLEN >=256, we only use fist 8 element of the vReg.
|
||||
vfloat32m2_t s0, s1, s2;
|
||||
if( initOutput )
|
||||
{
|
||||
s0 = vfmv_v_f_f32m2(bias0, unroll_tail);
|
||||
s1 = vfmv_v_f_f32m2(bias1, unroll_tail);
|
||||
s2 = vfmv_v_f_f32m2(bias2, unroll_tail);
|
||||
}
|
||||
else
|
||||
{
|
||||
s0 = vle32_v_f32m2(outptr0 + j, unroll_tail);
|
||||
s1 = vle32_v_f32m2(outptr1 + j, unroll_tail);
|
||||
s2 = vle32_v_f32m2(outptr2 + j, unroll_tail);
|
||||
}
|
||||
s0 = vfadd_vv_f32m2(vle32_v_f32m2(sum0, unroll_tail), s0, unroll_tail);
|
||||
s1 = vfadd_vv_f32m2(vle32_v_f32m2(sum1, unroll_tail), s1, unroll_tail);
|
||||
s2 = vfadd_vv_f32m2(vle32_v_f32m2(sum2, unroll_tail), s2, unroll_tail);
|
||||
|
||||
if( relu )
|
||||
{
|
||||
float r0 = relu[i], r1 = relu[i+1], r2 = relu[i+2];
|
||||
if( i+2 >= outCn )
|
||||
{
|
||||
r2 = r1;
|
||||
if( i+1 >= outCn )
|
||||
r2 = r1 = r0;
|
||||
}
|
||||
vbool16_t m0 = vmfgt_vf_f32m2_b16(s0, 0, unroll_tail);
|
||||
vbool16_t m1 = vmfgt_vf_f32m2_b16(s1, 0, unroll_tail);
|
||||
vbool16_t m2 = vmfgt_vf_f32m2_b16(s2, 0, unroll_tail);
|
||||
s0 = vmerge_vvm_f32m2(m0, vfmul_vf_f32m2(s0, r0, unroll_tail), s0, unroll_tail);
|
||||
s1 = vmerge_vvm_f32m2(m1, vfmul_vf_f32m2(s1, r1, unroll_tail), s1, unroll_tail);
|
||||
s2 = vmerge_vvm_f32m2(m2, vfmul_vf_f32m2(s2, r2, unroll_tail), s2, unroll_tail);
|
||||
}
|
||||
|
||||
vse32_v_f32m2(outptr0 + j, s0, unroll_tail);
|
||||
vse32_v_f32m2(outptr1 + j, s1, unroll_tail);
|
||||
vse32_v_f32m2(outptr2 + j, s2, unroll_tail);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Example for load_deinterleave:
|
||||
input: ptr[16] = {1,2,3, ... ,14,15,16}
|
||||
@ -1345,317 +850,6 @@ void fastDepthwiseConv( const float* wptr,
|
||||
|
||||
#if !defined(CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY) && CV_LASX
|
||||
|
||||
enum { FASCONV_BASE_VECSZ = 4 };
|
||||
|
||||
void fastConv( const float* weights, size_t wstep, const float* bias,
|
||||
const float* rowbuf, float* output, const int* outShape,
|
||||
int blockSize, int vecsize, int vecsize_aligned,
|
||||
const float* relu, bool initOutput )
|
||||
{
|
||||
int outCn = outShape[1];
|
||||
size_t outPlaneSize = outShape[2]*outShape[3];
|
||||
float r0 = 1.f, r1 = 1.f, r2 = 1.f;
|
||||
__m256 t1 = _v256_setall_ps(1.f), t2 = _v256_setall_ps(0.f);
|
||||
__m128 vr0 = *(__m128*)&t1, vr1 = vr0, vr2 = vr0, z = *(__m128*)&t2;
|
||||
int CV_DECL_ALIGNED(16) maskbuf[FASCONV_BASE_VECSZ] = {0};
|
||||
int rsz = blockSize % FASCONV_BASE_VECSZ;
|
||||
for( int i = 0; i < rsz; i++ )
|
||||
maskbuf[FASCONV_BASE_VECSZ - i - 1] = -1;
|
||||
__m128i mask = __lsx_vld((const float*)maskbuf, 0);
|
||||
|
||||
// now compute dot product of the weights
|
||||
// and im2row-transformed part of the tensor
|
||||
for( int i = 0; i < outCn; i += 3 )
|
||||
{
|
||||
const float* wptr0 = weights + i*wstep;
|
||||
const float* wptr1 = wptr0 + wstep;
|
||||
const float* wptr2 = wptr1 + wstep;
|
||||
float* outptr0 = output + i*outPlaneSize;
|
||||
float* outptr1 = outptr0 + outPlaneSize;
|
||||
float* outptr2 = outptr1 + outPlaneSize;
|
||||
float bias0 = bias[i], bias1 = bias[i+1], bias2 = bias[i+2];
|
||||
|
||||
if( i+2 >= outCn )
|
||||
{
|
||||
wptr2 = wptr1;
|
||||
outptr2 = outptr1;
|
||||
bias2 = bias1;
|
||||
if( i+1 >= outCn )
|
||||
{
|
||||
wptr2 = wptr1 = wptr0;
|
||||
outptr2 = outptr1 = outptr0;
|
||||
bias2 = bias1 = bias0;
|
||||
}
|
||||
}
|
||||
|
||||
if( relu )
|
||||
{
|
||||
r0 = relu[i]; r1 = relu[i+1]; r2 = relu[i+2];
|
||||
if( i+2 >= outCn )
|
||||
{
|
||||
r2 = r1;
|
||||
if( i+1 >= outCn )
|
||||
r2 = r1 = r0;
|
||||
}
|
||||
vr0 = _v256_extract_low(_v256_setall_ps(r0));
|
||||
vr1 = _v256_extract_low(_v256_setall_ps(r1));
|
||||
vr2 = _v256_extract_low(_v256_setall_ps(r2));
|
||||
}
|
||||
|
||||
int j = 0;
|
||||
for( ; j < blockSize; j += FASCONV_BASE_VECSZ )
|
||||
{
|
||||
bool tail = false;
|
||||
if (j + FASCONV_BASE_VECSZ > blockSize)
|
||||
{
|
||||
if (j == 0)
|
||||
break;
|
||||
j = blockSize - FASCONV_BASE_VECSZ;
|
||||
tail = true;
|
||||
}
|
||||
int k = 0;
|
||||
const float* rptr = rowbuf + j*vecsize_aligned;
|
||||
|
||||
__m256i tmp;
|
||||
__m256 vs00 = (__m256)__lasx_xvxor_v(tmp, tmp), vs01 = (__m256)__lasx_xvxor_v(tmp, tmp),
|
||||
vs02 = (__m256)__lasx_xvxor_v(tmp, tmp), vs03 = (__m256)__lasx_xvxor_v(tmp, tmp),
|
||||
vs10 = (__m256)__lasx_xvxor_v(tmp, tmp), vs11 = (__m256)__lasx_xvxor_v(tmp, tmp),
|
||||
vs12 = (__m256)__lasx_xvxor_v(tmp, tmp), vs13 = (__m256)__lasx_xvxor_v(tmp, tmp),
|
||||
vs20 = (__m256)__lasx_xvxor_v(tmp, tmp), vs21 = (__m256)__lasx_xvxor_v(tmp, tmp),
|
||||
vs22 = (__m256)__lasx_xvxor_v(tmp, tmp), vs23 = (__m256)__lasx_xvxor_v(tmp, tmp);
|
||||
|
||||
for (; k < vecsize; k += 8, rptr += 8 )
|
||||
{
|
||||
__m256 w0 = (__m256)__lasx_xvld(wptr0 + k, 0);
|
||||
__m256 w1 = (__m256)__lasx_xvld(wptr1 + k, 0);
|
||||
__m256 w2 = (__m256)__lasx_xvld(wptr2 + k, 0);
|
||||
__m256 r0 = (__m256)__lasx_xvld(rptr, 0);
|
||||
|
||||
vs00 = __lasx_xvfmadd_s(w0, r0, vs00);
|
||||
vs10 = __lasx_xvfmadd_s(w1, r0, vs10);
|
||||
vs20 = __lasx_xvfmadd_s(w2, r0, vs20);
|
||||
|
||||
r0 = (__m256)__lasx_xvld(rptr + vecsize_aligned, 0);
|
||||
vs01 = __lasx_xvfmadd_s(w0, r0, vs01);
|
||||
vs11 = __lasx_xvfmadd_s(w1, r0, vs11);
|
||||
vs21 = __lasx_xvfmadd_s(w2, r0, vs21);
|
||||
|
||||
r0 = (__m256)__lasx_xvld(rptr + vecsize_aligned*2, 0);
|
||||
vs02 = __lasx_xvfmadd_s(w0, r0, vs02);
|
||||
vs12 = __lasx_xvfmadd_s(w1, r0, vs12);
|
||||
vs22 = __lasx_xvfmadd_s(w2, r0, vs22);
|
||||
|
||||
r0 = (__m256)__lasx_xvld(rptr + vecsize_aligned*3, 0);
|
||||
vs03 = __lasx_xvfmadd_s(w0, r0, vs03);
|
||||
vs13 = __lasx_xvfmadd_s(w1, r0, vs13);
|
||||
vs23 = __lasx_xvfmadd_s(w2, r0, vs23);
|
||||
}
|
||||
|
||||
/*t0*/
|
||||
__m256 vs00_perm = (__m256)__lasx_xvpermi_d(vs00, (2<<6) + (3<<4) + (0<<2) + 1);
|
||||
__m256 vs00_add_2w = __lasx_xvfadd_s(vs00, vs00_perm);
|
||||
__m256 tmp00_srl = (__m256)__lasx_xvsrli_d(vs00_add_2w, 32);
|
||||
__m256 vs00_add_4w = __lasx_xvfadd_s(vs00_add_2w, tmp00_srl);
|
||||
|
||||
__m256 vs01_perm = (__m256)__lasx_xvpermi_d(vs01, (2<<6) + (3<<4) + (0<<2) + 1);
|
||||
__m256 vs01_add_2w = __lasx_xvfadd_s(vs01, vs01_perm);
|
||||
__m256 tmp01_srl = (__m256)__lasx_xvsrli_d(vs01_add_2w, 32);
|
||||
__m256 vs01_add_4w = __lasx_xvfadd_s(vs01_add_2w, tmp01_srl);
|
||||
|
||||
__m256 vs02_perm = (__m256)__lasx_xvpermi_d(vs02, (2<<6) + (3<<4) + (0<<2) + 1);
|
||||
__m256 vs02_add_2w = __lasx_xvfadd_s(vs02, vs02_perm);
|
||||
__m256 tmp02_srl = (__m256)__lasx_xvsrli_d(vs02_add_2w, 32);
|
||||
__m256 vs02_add_4w = __lasx_xvfadd_s(vs02_add_2w, tmp02_srl);
|
||||
|
||||
__m256 vs03_perm = (__m256)__lasx_xvpermi_d(vs03, (2<<6) + (3<<4) + (0<<2) + 1);
|
||||
__m256 vs03_add_2w = __lasx_xvfadd_s(vs03, vs03_perm);
|
||||
__m256 tmp03_srl = (__m256)__lasx_xvsrli_d(vs03_add_2w, 32);
|
||||
__m256 vs03_add_4w = __lasx_xvfadd_s(vs03_add_2w, tmp03_srl);
|
||||
|
||||
__m256i vs01_vs00 = __lasx_xvpackev_w((__m256i)vs01_add_4w, (__m256i)vs00_add_4w);
|
||||
__m256i vs03_vs02 = __lasx_xvpackev_w((__m256i)vs03_add_4w, (__m256i)vs02_add_4w);
|
||||
__m256 t0 = (__m256)__lasx_xvpackev_d(vs03_vs02, vs01_vs00);
|
||||
|
||||
/*t1*/
|
||||
__m256 vs10_perm = (__m256)__lasx_xvpermi_d(vs10, (2<<6) + (3<<4) + (0<<2) + 1);
|
||||
__m256 vs10_add_2w = __lasx_xvfadd_s(vs10, vs10_perm);
|
||||
__m256 tmp10_srl = (__m256)__lasx_xvsrli_d(vs10_add_2w, 32);
|
||||
__m256 vs10_add_4w = __lasx_xvfadd_s(vs10_add_2w, tmp10_srl);
|
||||
|
||||
__m256 vs11_perm = (__m256)__lasx_xvpermi_d(vs11, (2<<6) + (3<<4) + (0<<2) + 1);
|
||||
__m256 vs11_add_2w = __lasx_xvfadd_s(vs11, vs11_perm);
|
||||
__m256 tmp11_srl = (__m256)__lasx_xvsrli_d(vs11_add_2w, 32);
|
||||
__m256 vs11_add_4w = __lasx_xvfadd_s(vs11_add_2w, tmp11_srl);
|
||||
|
||||
__m256 vs12_perm = (__m256)__lasx_xvpermi_d(vs12, (2<<6) + (3<<4) + (0<<2) + 1);
|
||||
__m256 vs12_add_2w = __lasx_xvfadd_s(vs12, vs12_perm);
|
||||
__m256 tmp12_srl = (__m256)__lasx_xvsrli_d(vs12_add_2w, 32);
|
||||
__m256 vs12_add_4w = __lasx_xvfadd_s(vs12_add_2w, tmp12_srl);
|
||||
|
||||
__m256 vs13_perm = (__m256)__lasx_xvpermi_d(vs13, (2<<6) + (3<<4) + (0<<2) + 1);
|
||||
__m256 vs13_add_2w = __lasx_xvfadd_s(vs13, vs13_perm);
|
||||
__m256 tmp13_srl = (__m256)__lasx_xvsrli_d(vs13_add_2w, 32);
|
||||
__m256 vs13_add_4w = __lasx_xvfadd_s(vs13_add_2w, tmp13_srl);
|
||||
|
||||
__m256i vs11_vs10 = __lasx_xvpackev_w((__m256i)vs11_add_4w, (__m256i)vs10_add_4w);
|
||||
__m256i vs13_vs12 = __lasx_xvpackev_w((__m256i)vs13_add_4w, (__m256i)vs12_add_4w);
|
||||
__m256 t1 = (__m256)__lasx_xvpackev_d(vs13_vs12, vs11_vs10);
|
||||
|
||||
/*t2*/
|
||||
__m256 vs20_perm = (__m256)__lasx_xvpermi_d(vs20, (2<<6) + (3<<4) + (0<<2) + 1);
|
||||
__m256 vs20_add_2w = __lasx_xvfadd_s(vs20, vs20_perm);
|
||||
__m256 tmp20_srl = (__m256)__lasx_xvsrli_d(vs20_add_2w, 32);
|
||||
__m256 vs20_add_4w = __lasx_xvfadd_s(vs20_add_2w, tmp20_srl);
|
||||
|
||||
__m256 vs21_perm = (__m256)__lasx_xvpermi_d(vs21, (2<<6) + (3<<4) + (0<<2) + 1);
|
||||
__m256 vs21_add_2w = __lasx_xvfadd_s(vs21, vs21_perm);
|
||||
__m256 tmp21_srl = (__m256)__lasx_xvsrli_d(vs21_add_2w, 32);
|
||||
__m256 vs21_add_4w = __lasx_xvfadd_s(vs21_add_2w, tmp21_srl);
|
||||
|
||||
__m256 vs22_perm = (__m256)__lasx_xvpermi_d(vs22, (2<<6) + (3<<4) + (0<<2) + 1);
|
||||
__m256 vs22_add_2w = __lasx_xvfadd_s(vs22, vs22_perm);
|
||||
__m256 tmp22_srl = (__m256)__lasx_xvsrli_d(vs22_add_2w, 32);
|
||||
__m256 vs22_add_4w = __lasx_xvfadd_s(vs22_add_2w, tmp22_srl);
|
||||
|
||||
__m256 vs23_perm = (__m256)__lasx_xvpermi_d(vs23, (2<<6) + (3<<4) + (0<<2) + 1);
|
||||
__m256 vs23_add_2w = __lasx_xvfadd_s(vs23, vs23_perm);
|
||||
__m256 tmp23_srl = (__m256)__lasx_xvsrli_d(vs23_add_2w, 32);
|
||||
__m256 vs23_add_4w = __lasx_xvfadd_s(vs23_add_2w, tmp23_srl);
|
||||
|
||||
__m256i vs21_vs20 = __lasx_xvpackev_w((__m256i)vs21_add_4w, (__m256i)vs20_add_4w);
|
||||
__m256i vs23_vs22 = __lasx_xvpackev_w((__m256i)vs23_add_4w, (__m256i)vs22_add_4w);
|
||||
__m256 t2 = (__m256)__lasx_xvpackev_d(vs23_vs22, vs21_vs20);
|
||||
|
||||
t0 = __lasx_xvfadd_s(t0, (__m256)__lasx_xvpermi_q(t0, t0, 1));
|
||||
t1 = __lasx_xvfadd_s(t1, (__m256)__lasx_xvpermi_q(t1, t1, 1));
|
||||
t2 = __lasx_xvfadd_s(t2, (__m256)__lasx_xvpermi_q(t2, t2, 1));
|
||||
|
||||
__m128 s0, s1, s2;
|
||||
|
||||
if( initOutput )
|
||||
{
|
||||
s0 = _v256_extract_low(_v256_setall_ps(bias0));
|
||||
s1 = _v256_extract_low(_v256_setall_ps(bias1));
|
||||
s2 = _v256_extract_low(_v256_setall_ps(bias2));
|
||||
}
|
||||
else
|
||||
{
|
||||
s0 = (__m128)__lsx_vld(outptr0 + j, 0);
|
||||
s1 = (__m128)__lsx_vld(outptr1 + j, 0);
|
||||
s2 = (__m128)__lsx_vld(outptr2 + j, 0);
|
||||
}
|
||||
|
||||
s0 = __lsx_vfadd_s(s0, *(__m128*)&t0);
|
||||
s1 = __lsx_vfadd_s(s1, *(__m128*)&t1);
|
||||
s2 = __lsx_vfadd_s(s2, *(__m128*)&t2);
|
||||
|
||||
if( relu )
|
||||
{
|
||||
__m128i m0 = __lsx_vfcmp_clt_s(z, s0);
|
||||
__m128i m1 = __lsx_vfcmp_clt_s(z, s1);
|
||||
__m128i m2 = __lsx_vfcmp_clt_s(z, s2);
|
||||
s0 = (__m128)__lsx_vbitsel_v((__m128i)__lsx_vfmul_s(s0, vr0), (__m128i)s0, m0);
|
||||
s1 = (__m128)__lsx_vbitsel_v((__m128i)__lsx_vfmul_s(s1, vr1), (__m128i)s1, m1);
|
||||
s2 = (__m128)__lsx_vbitsel_v((__m128i)__lsx_vfmul_s(s2, vr2), (__m128i)s2, m2);
|
||||
}
|
||||
|
||||
if( tail )
|
||||
{
|
||||
s0 = (__m128)__lsx_vbitsel_v(__lsx_vld(outptr0 + j, 0), (__m128i)s0, mask);
|
||||
s1 = (__m128)__lsx_vbitsel_v(__lsx_vld(outptr1 + j, 0), (__m128i)s1, mask);
|
||||
s2 = (__m128)__lsx_vbitsel_v(__lsx_vld(outptr2 + j, 0), (__m128i)s2, mask);
|
||||
}
|
||||
|
||||
__lsx_vst(s0, outptr0 + j, 0);
|
||||
__lsx_vst(s1, outptr1 + j, 0);
|
||||
__lsx_vst(s2, outptr2 + j, 0);
|
||||
}
|
||||
|
||||
for( ; j <= blockSize - 2; j += 2 )
|
||||
{
|
||||
const float* rptr0 = rowbuf + j*vecsize_aligned;
|
||||
const float* rptr1 = rowbuf + (j+1)*vecsize_aligned;
|
||||
float s00, s01, s10, s11, s20, s21;
|
||||
|
||||
if( initOutput )
|
||||
{
|
||||
s00 = s01 = bias0;
|
||||
s10 = s11 = bias1;
|
||||
s20 = s21 = bias2;
|
||||
}
|
||||
else
|
||||
{
|
||||
s00 = outptr0[j]; s01 = outptr0[j+1];
|
||||
s10 = outptr1[j]; s11 = outptr1[j+1];
|
||||
s20 = outptr2[j]; s21 = outptr2[j+1];
|
||||
}
|
||||
|
||||
for( int k = 0; k < vecsize; k++ )
|
||||
{
|
||||
float w0 = wptr0[k], w1 = wptr1[k], w2 = wptr2[k];
|
||||
float r = rptr0[k];
|
||||
s00 += w0*r; s10 += w1*r; s20 += w2*r;
|
||||
r = rptr1[k];
|
||||
s01 += w0*r; s11 += w1*r; s21 += w2*r;
|
||||
}
|
||||
|
||||
if( relu )
|
||||
{
|
||||
s00 = s00 > 0.f ? s00 : s00*r0;
|
||||
s01 = s01 > 0.f ? s01 : s01*r0;
|
||||
s10 = s10 > 0.f ? s10 : s10*r1;
|
||||
s11 = s11 > 0.f ? s11 : s11*r1;
|
||||
s20 = s20 > 0.f ? s20 : s20*r2;
|
||||
s21 = s21 > 0.f ? s21 : s21*r2;
|
||||
}
|
||||
|
||||
outptr0[j] = s00;
|
||||
outptr0[j+1] = s01;
|
||||
outptr1[j] = s10;
|
||||
outptr1[j+1] = s11;
|
||||
outptr2[j] = s20;
|
||||
outptr2[j+1] = s21;
|
||||
}
|
||||
|
||||
for( ; j < blockSize; j++ )
|
||||
{
|
||||
const float* rptr0 = rowbuf + j*vecsize_aligned;
|
||||
float s00, s10, s20;
|
||||
|
||||
if( initOutput )
|
||||
{
|
||||
s00 = bias0;
|
||||
s10 = bias1;
|
||||
s20 = bias2;
|
||||
}
|
||||
else
|
||||
{
|
||||
s00 = outptr0[j];
|
||||
s10 = outptr1[j];
|
||||
s20 = outptr2[j];
|
||||
}
|
||||
|
||||
for( int k = 0; k < vecsize; k++ )
|
||||
{
|
||||
float w0 = wptr0[k], w1 = wptr1[k], w2 = wptr2[k];
|
||||
float r = rptr0[k];
|
||||
s00 += w0*r; s10 += w1*r; s20 += w2*r;
|
||||
}
|
||||
|
||||
if( relu )
|
||||
{
|
||||
s00 = s00 > 0.f ? s00 : s00*r0;
|
||||
s10 = s10 > 0.f ? s10 : s10*r1;
|
||||
s20 = s20 > 0.f ? s20 : s20*r2;
|
||||
}
|
||||
|
||||
outptr0[j] = s00;
|
||||
outptr1[j] = s10;
|
||||
outptr2[j] = s20;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static inline void _v256_load_deinterleave(const float* ptr, __m256& a, __m256& b)
|
||||
{
|
||||
__m256 t0 = (__m256)__lasx_xvld(ptr, 0);
|
||||
|
Loading…
Reference in New Issue
Block a user