mirror of
https://github.com/opencv/opencv.git
synced 2025-08-06 06:26:29 +08:00
Merge pull request #10795 from pengli:dnn
This commit is contained in:
commit
398ebbac98
@ -1386,8 +1386,11 @@ struct Net::Impl
|
||||
|
||||
if ( preferableTarget == DNN_TARGET_OPENCL )
|
||||
{
|
||||
nextData = &layers[activData->consumers[0].lid];
|
||||
lpNext = LayerPin(activData->consumers[0].lid, 0);
|
||||
if ( !activData->consumers.empty() )
|
||||
{
|
||||
nextData = &layers[activData->consumers[0].lid];
|
||||
lpNext = LayerPin(activData->consumers[0].lid, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -181,7 +181,8 @@ public:
|
||||
inputs_.getUMatVector(inputs);
|
||||
outputs_.getUMatVector(outputs);
|
||||
|
||||
if (inputs[0].dims < 4)
|
||||
if (inputs[0].dims < 4 || (total(shape(outputs[0]), 0, 2) % 4 != 0) ||
|
||||
(total(shape(outputs[0]), 2) % 4 != 0))
|
||||
return false;
|
||||
|
||||
const UMat& inpMat = inputs[0];
|
||||
@ -192,22 +193,19 @@ public:
|
||||
int rows = outputs[i].size[2];
|
||||
int cols = outputs[i].size[3];
|
||||
|
||||
int number = (cols % 8 == 0) ? 8 : ((cols % 4 == 0) ? 4 : 1);
|
||||
String buildopt = format("-DNUM=%d ", number);
|
||||
String kname = format("slice%d", number);
|
||||
ocl::Kernel kernel(kname.c_str(), ocl::dnn::slice_oclsrc, buildopt);
|
||||
size_t global[] = { (size_t)groups * channels, (size_t)rows * cols / number };
|
||||
ocl::Kernel kernel("slice", ocl::dnn::slice_oclsrc);
|
||||
size_t local[] = { 128 };
|
||||
size_t global[] = { (size_t)groups * channels / 4 * local[0] };
|
||||
int idx = 0;
|
||||
kernel.set(idx++, ocl::KernelArg::PtrReadOnly(inpMat));
|
||||
kernel.set(idx++, (int)(inpMat.size[2] * inpMat.size[3]));
|
||||
kernel.set(idx++, (int)inpMat.size[3]);
|
||||
kernel.set(idx++, (int)global[0]);
|
||||
kernel.set(idx++, (int)(rows * cols));
|
||||
kernel.set(idx++, (int)inpMat.size[3]);
|
||||
kernel.set(idx++, (int)cols);
|
||||
kernel.set(idx++, (int)sliceRanges[i][2].start);
|
||||
kernel.set(idx++, (int)sliceRanges[i][3].start);
|
||||
kernel.set(idx++, ocl::KernelArg::PtrWriteOnly(outputs[i]));
|
||||
bool ret = kernel.run(2, global, NULL, false);
|
||||
bool ret = kernel.run(1, global, local, false);
|
||||
if (!ret)
|
||||
return false;
|
||||
}
|
||||
|
@ -44,44 +44,38 @@
|
||||
#define Dtype4 float4
|
||||
#define Dtype8 float8
|
||||
|
||||
#if NUM == 8
|
||||
#define load(src, index) vload8(0, src + index)
|
||||
#define store(vec, dst, index) vstore8(vec, 0, dst + index)
|
||||
#define vec_type Dtype8
|
||||
#define SLICE slice8
|
||||
#elif NUM == 4
|
||||
#define load(src, index) vload4(0, src + index)
|
||||
#define store(vec, dst, index) vstore4(vec, 0, dst + index)
|
||||
#define vec_type Dtype4
|
||||
#define SLICE slice4
|
||||
#elif NUM == 1
|
||||
#define load(src, index) src[index]
|
||||
#define store(vec, dst, index) dst[index] = vec
|
||||
#define vec_type Dtype
|
||||
#define SLICE slice1
|
||||
#endif
|
||||
|
||||
__kernel void SLICE(__global const Dtype* src,
|
||||
__kernel void slice(__global const Dtype* src,
|
||||
const int src_plane_size,
|
||||
const int src_cols,
|
||||
const int channels,
|
||||
const int dst_plane_size,
|
||||
const int src_cols,
|
||||
const int dst_cols,
|
||||
const int row_offset,
|
||||
const int col_offset,
|
||||
__global Dtype* dst)
|
||||
{
|
||||
int x = get_global_id(0);
|
||||
int y = get_global_id(1) * NUM;
|
||||
unsigned int row_gid = get_group_id(0);
|
||||
unsigned int lid = get_local_id(0);
|
||||
const __global Dtype *src_read = src + row_gid * 4 * src_plane_size;
|
||||
__global Dtype *dst_read = dst + row_gid * 4 * dst_plane_size;
|
||||
Dtype4 a0, a1, a2, a3;
|
||||
|
||||
if ((x >= channels) || (y >= dst_plane_size))
|
||||
return;
|
||||
int i = lid;
|
||||
while( i < dst_plane_size / 4)
|
||||
{
|
||||
int row = (4 * i) / dst_cols + row_offset;
|
||||
int col = (4 * i) % dst_cols + col_offset;
|
||||
int src_index = row * src_cols + col;
|
||||
|
||||
int row = y / dst_cols + row_offset;
|
||||
int col = y % dst_cols + col_offset;
|
||||
a0 = vload4(0, src_read + src_index);
|
||||
a1 = vload4(0, src_read + src_index + src_plane_size);
|
||||
a2 = vload4(0, src_read + src_index + 2 * src_plane_size);
|
||||
a3 = vload4(0, src_read + src_index + 3 * src_plane_size);
|
||||
|
||||
int src_index = x * src_plane_size + row * src_cols + col;
|
||||
int dst_index = x * dst_plane_size + y;
|
||||
vec_type val = load(src, src_index);
|
||||
store(val, dst, dst_index);
|
||||
vstore4(a0, i, dst_read);
|
||||
vstore4(a1, i, dst_read + dst_plane_size);
|
||||
vstore4(a2, i, dst_read + 2 * dst_plane_size);
|
||||
vstore4(a3, i, dst_read + 3 * dst_plane_size);
|
||||
|
||||
i += get_local_size(0);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user