mirror of
https://github.com/opencv/opencv.git
synced 2025-06-08 01:53:19 +08:00
Merge pull request #13762 from alalek:ocl_dnn_mvn_local_size
This commit is contained in:
commit
a63f66c90e
@ -138,9 +138,12 @@ public:
|
|||||||
UMat& bnorm_weight = umat_scale;
|
UMat& bnorm_weight = umat_scale;
|
||||||
UMat& bnorm_bias = umat_shift;
|
UMat& bnorm_bias = umat_shift;
|
||||||
|
|
||||||
|
const unsigned LOCAL_SIZE = 128;
|
||||||
bool use_half = (inputs[0].depth() == CV_16S);
|
bool use_half = (inputs[0].depth() == CV_16S);
|
||||||
String opts = format(" -DT=%s -DT4=%s -Dconvert_T=%s", use_half ? "half" : "float",
|
String opts = format(" -DT=%s -DT4=%s -Dconvert_T=%s -DLOCAL_SIZE=%u", use_half ? "half" : "float",
|
||||||
use_half ? "half4" : "float4", use_half ? "convert_half4" : "convert_float4");
|
use_half ? "half4" : "float4", use_half ? "convert_half4" : "convert_float4",
|
||||||
|
LOCAL_SIZE
|
||||||
|
);
|
||||||
|
|
||||||
int splitDim = (acrossChannels) ? 1 : 2;
|
int splitDim = (acrossChannels) ? 1 : 2;
|
||||||
for (size_t inpIdx = 0; inpIdx < inputs.size(); inpIdx++)
|
for (size_t inpIdx = 0; inpIdx < inputs.size(); inpIdx++)
|
||||||
@ -155,8 +158,8 @@ public:
|
|||||||
float alpha = 1.0f / s[1];
|
float alpha = 1.0f / s[1];
|
||||||
|
|
||||||
String buildopt = "-DNUM=4" + opts;
|
String buildopt = "-DNUM=4" + opts;
|
||||||
ocl::Kernel k("mean_fuse4", ocl::dnn::mvn_oclsrc, buildopt);
|
ocl::Kernel k("mean_fuse4", ocl::dnn::mvn_oclsrc, buildopt + " -DKERNEL_MEAN_FUSE");
|
||||||
size_t localsize[] = { 128 };
|
size_t localsize[] = { LOCAL_SIZE };
|
||||||
size_t globalsize[] = { (size_t)s[0] / 4 * localsize[0] };
|
size_t globalsize[] = { (size_t)s[0] / 4 * localsize[0] };
|
||||||
|
|
||||||
int argId = 0;
|
int argId = 0;
|
||||||
@ -165,7 +168,6 @@ public:
|
|||||||
k.set(argId++, alpha);
|
k.set(argId++, alpha);
|
||||||
k.set(argId++, ocl::KernelArg::PtrWriteOnly(meanMat));
|
k.set(argId++, ocl::KernelArg::PtrWriteOnly(meanMat));
|
||||||
k.set(argId++, ocl::KernelArg::PtrWriteOnly(tmpMat));
|
k.set(argId++, ocl::KernelArg::PtrWriteOnly(tmpMat));
|
||||||
k.set(argId++, NULL, localsize[0] * sizeof(cl_float4));
|
|
||||||
bool ret = k.run(1, globalsize, localsize, false);
|
bool ret = k.run(1, globalsize, localsize, false);
|
||||||
if (!ret)
|
if (!ret)
|
||||||
return false;
|
return false;
|
||||||
@ -173,7 +175,7 @@ public:
|
|||||||
buildopt += format(" %s %s", (fuse_batch_norm) ? "-DFUSE_BATCH_NORM" : "",
|
buildopt += format(" %s %s", (fuse_batch_norm) ? "-DFUSE_BATCH_NORM" : "",
|
||||||
(fuse_relu) ? "-DFUSE_RELU" : "");
|
(fuse_relu) ? "-DFUSE_RELU" : "");
|
||||||
|
|
||||||
ocl::Kernel k1("mvn_fuse4", ocl::dnn::mvn_oclsrc, buildopt);
|
ocl::Kernel k1("mvn_fuse4", ocl::dnn::mvn_oclsrc, buildopt + " -DKERNEL_MVN_FUSE");
|
||||||
argId = 0;
|
argId = 0;
|
||||||
k1.set(argId++, ocl::KernelArg::PtrReadOnly(tmpMat));
|
k1.set(argId++, ocl::KernelArg::PtrReadOnly(tmpMat));
|
||||||
k1.set(argId++, ocl::KernelArg::PtrReadOnly(inpMat));
|
k1.set(argId++, ocl::KernelArg::PtrReadOnly(inpMat));
|
||||||
@ -185,7 +187,6 @@ public:
|
|||||||
k1.set(argId++, ocl::KernelArg::PtrReadOnly(bnorm_weight));
|
k1.set(argId++, ocl::KernelArg::PtrReadOnly(bnorm_weight));
|
||||||
k1.set(argId++, ocl::KernelArg::PtrReadOnly(bnorm_bias));
|
k1.set(argId++, ocl::KernelArg::PtrReadOnly(bnorm_bias));
|
||||||
k1.set(argId++, ocl::KernelArg::PtrWriteOnly(outMat));
|
k1.set(argId++, ocl::KernelArg::PtrWriteOnly(outMat));
|
||||||
k1.set(argId++, NULL, localsize[0] * sizeof(cl_float4));
|
|
||||||
ret = k1.run(1, globalsize, localsize, false);
|
ret = k1.run(1, globalsize, localsize, false);
|
||||||
if (!ret)
|
if (!ret)
|
||||||
return false;
|
return false;
|
||||||
@ -243,7 +244,7 @@ public:
|
|||||||
if (normVariance)
|
if (normVariance)
|
||||||
{
|
{
|
||||||
String kname = format("calc_mean%d", number);
|
String kname = format("calc_mean%d", number);
|
||||||
ocl::Kernel kernel(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt);
|
ocl::Kernel kernel(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt + " -DKERNEL_MEAN");
|
||||||
if (kernel.empty())
|
if (kernel.empty())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
@ -263,7 +264,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
String kname = format("mvn%d", number);
|
String kname = format("mvn%d", number);
|
||||||
buildopt += format("%s%s%s", (normVariance) ? " -DNORM_VARIANCE" : "",
|
buildopt += format("%s%s%s -DKERNEL_MVN", (normVariance) ? " -DNORM_VARIANCE" : "",
|
||||||
(fuse_batch_norm) ? " -DFUSE_BATCH_NORM" : "",
|
(fuse_batch_norm) ? " -DFUSE_BATCH_NORM" : "",
|
||||||
(fuse_relu) ? " -DFUSE_RELU" : "");
|
(fuse_relu) ? " -DFUSE_RELU" : "");
|
||||||
ocl::Kernel kernel1(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt);
|
ocl::Kernel kernel1(kname.c_str(), ocl::dnn::mvn_oclsrc, buildopt);
|
||||||
|
@ -74,6 +74,8 @@
|
|||||||
#define MVN_FUSE mvn_fuse1
|
#define MVN_FUSE mvn_fuse1
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef KERNEL_MEAN
|
||||||
|
|
||||||
__kernel void CALC_MEAN(__global const Dtype* src,
|
__kernel void CALC_MEAN(__global const Dtype* src,
|
||||||
const int rows,
|
const int rows,
|
||||||
const int cols,
|
const int cols,
|
||||||
@ -94,6 +96,8 @@ __kernel void CALC_MEAN(__global const Dtype* src,
|
|||||||
store(dst_vec, dst, index);
|
store(dst_vec, dst, index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#elif defined KERNEL_MVN
|
||||||
|
|
||||||
__kernel void MVN(__global const Dtype* src,
|
__kernel void MVN(__global const Dtype* src,
|
||||||
const int rows,
|
const int rows,
|
||||||
const int cols,
|
const int cols,
|
||||||
@ -140,12 +144,13 @@ __kernel void MVN(__global const Dtype* src,
|
|||||||
store(dst_vec, dst, index);
|
store(dst_vec, dst, index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#elif defined KERNEL_MEAN_FUSE
|
||||||
|
|
||||||
__kernel void MEAN_FUSE(__global const T * A,
|
__kernel void MEAN_FUSE(__global const T * A,
|
||||||
unsigned int A_col_size,
|
unsigned int A_col_size,
|
||||||
float alpha,
|
float alpha,
|
||||||
__global T4 * mean,
|
__global T4 * mean,
|
||||||
__global Dtype * tmp,
|
__global Dtype * tmp)
|
||||||
__local Dtype4 * work)
|
|
||||||
{
|
{
|
||||||
unsigned int row_gid = get_group_id(0);
|
unsigned int row_gid = get_group_id(0);
|
||||||
unsigned int lid = get_local_id(0);
|
unsigned int lid = get_local_id(0);
|
||||||
@ -168,15 +173,16 @@ __kernel void MEAN_FUSE(__global const T * A,
|
|||||||
dot2 += convert_float4(a2);
|
dot2 += convert_float4(a2);
|
||||||
dot3 += convert_float4(a3);
|
dot3 += convert_float4(a3);
|
||||||
|
|
||||||
i += get_local_size(0);
|
i += LOCAL_SIZE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__local Dtype4 work[LOCAL_SIZE];
|
||||||
work[lid].s0 = dot(dot0, b0);
|
work[lid].s0 = dot(dot0, b0);
|
||||||
work[lid].s1 = dot(dot1, b0);
|
work[lid].s1 = dot(dot1, b0);
|
||||||
work[lid].s2 = dot(dot2, b0);
|
work[lid].s2 = dot(dot2, b0);
|
||||||
work[lid].s3 = dot(dot3, b0);
|
work[lid].s3 = dot(dot3, b0);
|
||||||
|
|
||||||
for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1)
|
for(unsigned int stride=LOCAL_SIZE/2 ; stride>0 ; stride>>=1)
|
||||||
{
|
{
|
||||||
barrier(CLK_LOCAL_MEM_FENCE);
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||||||
if(lid < stride)
|
if(lid < stride)
|
||||||
@ -212,10 +218,12 @@ __kernel void MEAN_FUSE(__global const T * A,
|
|||||||
vstore4(dot2, i, dst0_read + 2 * A_col_size);
|
vstore4(dot2, i, dst0_read + 2 * A_col_size);
|
||||||
vstore4(dot3, i, dst0_read + 3 * A_col_size);
|
vstore4(dot3, i, dst0_read + 3 * A_col_size);
|
||||||
|
|
||||||
i += get_local_size(0);
|
i += LOCAL_SIZE;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#elif defined KERNEL_MVN_FUSE
|
||||||
|
|
||||||
__kernel void MVN_FUSE(__global const Dtype * tmp,
|
__kernel void MVN_FUSE(__global const Dtype * tmp,
|
||||||
__global const T * A,
|
__global const T * A,
|
||||||
__global const T4 * mean,
|
__global const T4 * mean,
|
||||||
@ -225,8 +233,7 @@ __kernel void MVN_FUSE(__global const Dtype * tmp,
|
|||||||
const float relu_slope,
|
const float relu_slope,
|
||||||
__global const Dtype4 * bnorm_weight,
|
__global const Dtype4 * bnorm_weight,
|
||||||
__global const Dtype4 * bnorm_bias,
|
__global const Dtype4 * bnorm_bias,
|
||||||
__global T * B,
|
__global T * B)
|
||||||
__local Dtype4 * work)
|
|
||||||
{
|
{
|
||||||
unsigned int row_gid = get_group_id(0);
|
unsigned int row_gid = get_group_id(0);
|
||||||
unsigned int lid = get_local_id(0);
|
unsigned int lid = get_local_id(0);
|
||||||
@ -250,15 +257,16 @@ __kernel void MVN_FUSE(__global const Dtype * tmp,
|
|||||||
dot2 += a2;
|
dot2 += a2;
|
||||||
dot3 += a3;
|
dot3 += a3;
|
||||||
|
|
||||||
i += get_local_size(0);
|
i += LOCAL_SIZE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__local Dtype4 work[LOCAL_SIZE];
|
||||||
work[lid].s0 = dot(dot0, b0);
|
work[lid].s0 = dot(dot0, b0);
|
||||||
work[lid].s1 = dot(dot1, b0);
|
work[lid].s1 = dot(dot1, b0);
|
||||||
work[lid].s2 = dot(dot2, b0);
|
work[lid].s2 = dot(dot2, b0);
|
||||||
work[lid].s3 = dot(dot3, b0);
|
work[lid].s3 = dot(dot3, b0);
|
||||||
|
|
||||||
for(unsigned int stride=get_local_size(0)/2 ; stride>0 ; stride>>=1)
|
for(unsigned int stride=LOCAL_SIZE/2 ; stride>0 ; stride>>=1)
|
||||||
{
|
{
|
||||||
barrier(CLK_LOCAL_MEM_FENCE);
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||||||
if(lid < stride)
|
if(lid < stride)
|
||||||
@ -314,6 +322,10 @@ __kernel void MVN_FUSE(__global const Dtype * tmp,
|
|||||||
vstore4(convert_T(dot2), i, dst0_read + 2 * A_col_size);
|
vstore4(convert_T(dot2), i, dst0_read + 2 * A_col_size);
|
||||||
vstore4(convert_T(dot3), i, dst0_read + 3 * A_col_size);
|
vstore4(convert_T(dot3), i, dst0_read + 3 * A_col_size);
|
||||||
|
|
||||||
i += get_local_size(0);
|
i += LOCAL_SIZE;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
#error "Configuration error!"
|
||||||
|
#endif
|
||||||
|
Loading…
Reference in New Issue
Block a user