Merge pull request #1626 from ilya-lavrenov:ocl_filters

This commit is contained in:
Andrey Pavlenko 2013-10-18 13:44:40 +04:00 committed by OpenCV Buildbot
commit 97dfd65007
9 changed files with 570 additions and 577 deletions

View File

@ -666,3 +666,17 @@ Performs linear blending of two images.
:param weights2: Weights for second image. Must have tha same size as ``img2`` . Supports only ``CV_32F`` type.
:param result: Destination image.
ocl::medianFilter
--------------------
Blurs an image using the median filter.
.. ocv:function:: void ocl::medianFilter(const oclMat &src, oclMat &dst, int m)
:param src: input ```1-``` or ```4```-channel image; the image depth should be ```CV_8U```, ```CV_32F```.
:param dst: destination array of the same size and type as ```src```.
:param m: aperture linear size; it must be odd and greater than ```1```. Currently only ```3```, ```5``` are supported.
The function smoothes an image using the median filter with the \texttt{m} \times \texttt{m} aperture. Each channel of a multi-channel image is processed independently. In-place operation is supported.

View File

@ -839,11 +839,8 @@ namespace cv
//! Applies a generic geometrical transformation to an image.
// Supports INTER_NEAREST, INTER_LINEAR.
// Map1 supports CV_16SC2, CV_32FC2 types.
// Src supports CV_8UC1, CV_8UC2, CV_8UC4.
CV_EXPORTS void remap(const oclMat &src, oclMat &dst, oclMat &map1, oclMat &map2, int interpolation, int bordertype, const Scalar &value = Scalar());
//! copies 2D array to a larger destination array and pads borders with user-specifiable constant
@ -851,7 +848,7 @@ namespace cv
CV_EXPORTS void copyMakeBorder(const oclMat &src, oclMat &dst, int top, int bottom, int left, int right, int boardtype, const Scalar &value = Scalar());
//! Smoothes image using median filter
// The source 1- or 4-channel image. When m is 3 or 5, the image depth should be CV 8U or CV 32F.
// The source 1- or 4-channel image. m should be 3 or 5, the image depth should be CV_8U or CV_32F.
CV_EXPORTS void medianFilter(const oclMat &src, oclMat &dst, int m);
//! warps the image using affine transformation

View File

@ -197,10 +197,10 @@ static void GPUErode(const oclMat &src, oclMat &dst, oclMat &mat_kernel,
(src.rows == dst.rows));
CV_Assert((src.oclchannels() == dst.oclchannels()));
int srcStep = src.step1() / src.oclchannels();
int dstStep = dst.step1() / dst.oclchannels();
int srcOffset = src.offset / src.elemSize();
int dstOffset = dst.offset / dst.elemSize();
int srcStep = src.step / src.elemSize();
int dstStep = dst.step / dst.elemSize();
int srcOffset = src.offset / src.elemSize();
int dstOffset = dst.offset / dst.elemSize();
int srcOffset_x = srcOffset % srcStep;
int srcOffset_y = srcOffset / srcStep;
@ -247,6 +247,7 @@ static void GPUErode(const oclMat &src, oclMat &dst, oclMat &mat_kernel,
sprintf(compile_option, "-D RADIUSX=%d -D RADIUSY=%d -D LSIZE0=%d -D LSIZE1=%d -D ERODE %s %s",
anchor.x, anchor.y, (int)localThreads[0], (int)localThreads[1],
s, rectKernel?"-D RECTKERNEL":"");
vector< pair<size_t, const void *> > args;
args.push_back(make_pair(sizeof(cl_mem), (void *)&src.data));
args.push_back(make_pair(sizeof(cl_mem), (void *)&dst.data));
@ -260,6 +261,7 @@ static void GPUErode(const oclMat &src, oclMat &dst, oclMat &mat_kernel,
args.push_back(make_pair(sizeof(cl_int), (void *)&src.wholecols));
args.push_back(make_pair(sizeof(cl_int), (void *)&src.wholerows));
args.push_back(make_pair(sizeof(cl_int), (void *)&dstOffset));
openCLExecuteKernel(clCxt, &filtering_morph, kernelName, globalThreads, localThreads, args, -1, -1, compile_option);
}
@ -351,7 +353,7 @@ Ptr<BaseFilter_GPU> cv::ocl::getMorphologyFilter_GPU(int op, int type, const Mat
};
CV_Assert(op == MORPH_ERODE || op == MORPH_DILATE);
CV_Assert(type == CV_8UC1 || type == CV_8UC3 || type == CV_8UC4 || type == CV_32FC1 || type == CV_32FC1 || type == CV_32FC4);
CV_Assert(type == CV_8UC1 || type == CV_8UC3 || type == CV_8UC4 || type == CV_32FC1 || type == CV_32FC3 || type == CV_32FC4);
oclMat gpu_krnl;
normalizeKernel(kernel, gpu_krnl);
@ -361,9 +363,11 @@ Ptr<BaseFilter_GPU> cv::ocl::getMorphologyFilter_GPU(int op, int type, const Mat
for(int i = 0; i < kernel.rows * kernel.cols; ++i)
if(kernel.data[i] != 1)
noZero = false;
MorphFilter_GPU* mfgpu=new MorphFilter_GPU(ksize, anchor, gpu_krnl, GPUMorfFilter_callers[op][CV_MAT_CN(type)]);
MorphFilter_GPU* mfgpu = new MorphFilter_GPU(ksize, anchor, gpu_krnl, GPUMorfFilter_callers[op][CV_MAT_CN(type)]);
if(noZero)
mfgpu->rectKernel = true;
return Ptr<BaseFilter_GPU>(mfgpu);
}
@ -445,9 +449,7 @@ void morphOp(int op, const oclMat &src, oclMat &dst, const Mat &_kernel, Point a
iterations = 1;
}
else
{
kernel = _kernel;
}
Ptr<FilterEngine_GPU> f = createMorphologyFilter_GPU(op, src.type(), kernel, anchor, iterations);
@ -462,14 +464,10 @@ void cv::ocl::erode(const oclMat &src, oclMat &dst, const Mat &kernel, Point anc
for (int i = 0; i < kernel.rows * kernel.cols; ++i)
if (kernel.data[i] != 0)
{
allZero = false;
}
if (allZero)
{
kernel.data[0] = 1;
}
morphOp(MORPH_ERODE, src, dst, kernel, anchor, iterations, borderType, borderValue);
}
@ -558,7 +556,7 @@ static void GPUFilter2D(const oclMat &src, oclMat &dst, const oclMat &mat_kernel
Context *clCxt = src.clCxt;
int filterWidth = ksize.width;
bool ksize_3x3 = filterWidth == 3 && src.type() != CV_32FC4; // CV_32FC4 is not tuned up with filter2d_3x3 kernel
bool ksize_3x3 = filterWidth == 3 && src.type() != CV_32FC4 && src.type() != CV_32FC3; // CV_32FC4 is not tuned up with filter2d_3x3 kernel
string kernelName = ksize_3x3 ? "filter2D_3x3" : "filter2D";
@ -649,9 +647,7 @@ Ptr<BaseFilter_GPU> cv::ocl::getLinearFilter_GPU(int srcType, int dstType, const
Ptr<FilterEngine_GPU> cv::ocl::createLinearFilter_GPU(int srcType, int dstType, const Mat &kernel, const Point &anchor,
int borderType)
{
Size ksize = kernel.size();
Ptr<BaseFilter_GPU> linearFilter = getLinearFilter_GPU(srcType, dstType, kernel, ksize, anchor, borderType);
return createFilter2D_GPU(linearFilter);
@ -659,11 +655,8 @@ Ptr<FilterEngine_GPU> cv::ocl::createLinearFilter_GPU(int srcType, int dstType,
void cv::ocl::filter2D(const oclMat &src, oclMat &dst, int ddepth, const Mat &kernel, Point anchor, int borderType)
{
if (ddepth < 0)
{
ddepth = src.depth();
}
dst.create(src.size(), CV_MAKETYPE(ddepth, src.channels()));
@ -1444,9 +1437,7 @@ Ptr<FilterEngine_GPU> cv::ocl::createGaussianFilter_GPU(int type, Size ksize, do
int depth = CV_MAT_DEPTH(type);
if (sigma2 <= 0)
{
sigma2 = sigma1;
}
// automatic detection of kernel size from sigma
if (ksize.width <= 0 && sigma1 > 0)

View File

@ -408,20 +408,11 @@ namespace cv
void medianFilter(const oclMat &src, oclMat &dst, int m)
{
CV_Assert( m % 2 == 1 && m > 1 );
CV_Assert( m <= 5 || src.depth() == CV_8U );
CV_Assert( src.cols <= dst.cols && src.rows <= dst.rows );
CV_Assert( (src.depth() == CV_8U || src.depth() == CV_32F) && (src.channels() == 1 || src.channels() == 4));
dst.create(src.size(), src.type());
if (src.data == dst.data)
{
oclMat src1;
src.copyTo(src1);
return medianFilter(src1, dst, m);
}
int srcStep = src.step1() / src.oclchannels();
int dstStep = dst.step1() / dst.oclchannels();
int srcOffset = src.offset / src.oclchannels() / src.elemSize1();
int dstOffset = dst.offset / dst.oclchannels() / dst.elemSize1();
int srcStep = src.step / src.elemSize(), dstStep = dst.step / dst.elemSize();
int srcOffset = src.offset / src.elemSize(), dstOffset = dst.offset / dst.elemSize();
Context *clCxt = src.clCxt;
@ -1518,6 +1509,7 @@ namespace cv
float *color_weight = &_color_weight[0];
float *space_weight = &_space_weight[0];
int *space_ofs = &_space_ofs[0];
int dst_step_in_pixel = dst.step / dst.elemSize();
int dst_offset_in_pixel = dst.offset / dst.elemSize();
int temp_step_in_pixel = temp.step / temp.elemSize();
@ -1548,7 +1540,7 @@ namespace cv
if ((dst.type() == CV_8UC1) && ((dst.offset & 3) == 0) && ((dst.cols & 3) == 0))
{
kernelName = "bilateral2";
globalThreads[0] = dst.cols / 4;
globalThreads[0] = dst.cols >> 2;
}
vector<pair<size_t , const void *> > args;
@ -1566,15 +1558,17 @@ namespace cv
args.push_back( make_pair( sizeof(cl_mem), (void *)&oclcolor_weight.data ));
args.push_back( make_pair( sizeof(cl_mem), (void *)&oclspace_weight.data ));
args.push_back( make_pair( sizeof(cl_mem), (void *)&oclspace_ofs.data ));
openCLExecuteKernel(src.clCxt, &imgproc_bilateral, kernelName, globalThreads, localThreads, args, dst.oclchannels(), dst.depth());
}
void bilateralFilter(const oclMat &src, oclMat &dst, int radius, double sigmaclr, double sigmaspc, int borderType)
{
dst.create( src.size(), src.type() );
if ( src.depth() == CV_8U )
oclbilateralFilter_8u( src, dst, radius, sigmaclr, sigmaspc, borderType );
else
CV_Error( CV_StsUnsupportedFormat, "Bilateral filtering is only implemented for 8uimages" );
CV_Error( CV_StsUnsupportedFormat, "Bilateral filtering is only implemented for CV_8U images" );
}
}

View File

@ -169,6 +169,7 @@ __kernel void filter2D(
int globalRow = groupStartRow + localRow;
const int src_offset = mad24(src_offset_y, src_step, src_offset_x);
const int dst_offset = mad24(dst_offset_y, dst_step, dst_offset_x);
#ifdef BORDER_CONSTANT
for(int i = localRow; i < LOCAL_HEIGHT; i += get_local_size(1))
{
@ -208,6 +209,7 @@ __kernel void filter2D(
}
}
#endif
barrier(CLK_LOCAL_MEM_FENCE);
if(globalRow < rows && globalCol < cols)
{
@ -231,6 +233,7 @@ __kernel void filter2D(
//////////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////Macro for define elements number per thread/////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
#define ANX 1
#define ANY 1
@ -249,6 +252,7 @@ __kernel void filter2D(
///////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////8uC1////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
__kernel void filter2D_3x3(
__global T_IMG *src,
__global T_IMG *dst,
@ -359,6 +363,7 @@ __kernel void filter2D_3x3(
}
}
}
if(dst_rows_index < dst_rows_end)
{
T_IMGx4 tmp_dst = CONVERT_TYPEx4(sum);

View File

@ -45,6 +45,7 @@
//BORDER_CONSTANT: iiiiii|abcdefgh|iiiiiii
#define ELEM(i,l_edge,r_edge,elem1,elem2) (i)<(l_edge) | (i) >= (r_edge) ? (elem1) : (elem2)
#ifndef GENTYPE
__kernel void morph_C1_D0(__global const uchar * restrict src,
__global uchar *dst,
int src_offset_x, int src_offset_y,
@ -150,7 +151,9 @@ __kernel void morph_C1_D0(__global const uchar * restrict src,
}
}
}
#else
__kernel void morph(__global const GENTYPE * restrict src,
__global GENTYPE *dst,
int src_offset_x, int src_offset_y,
@ -221,4 +224,5 @@ __kernel void morph(__global const GENTYPE * restrict src,
dst[out_addr] = res;
}
}
#endif

View File

@ -47,25 +47,27 @@ __kernel void bilateral_C1_D0(__global uchar *dst,
__constant float *space_weight,
__constant int *space_ofs)
{
int gidx = get_global_id(0);
int gidy = get_global_id(1);
if((gidy<dst_rows) && (gidx<dst_cols))
int x = get_global_id(0);
int y = get_global_id(1);
if (y < dst_rows && x < dst_cols)
{
int src_addr = mad24(gidy+radius,src_step,gidx+radius);
int dst_addr = mad24(gidy,dst_step,gidx+dst_offset);
int src_index = mad24(y + radius, src_step, x + radius);
int dst_index = mad24(y, dst_step, x + dst_offset);
float sum = 0.f, wsum = 0.f;
int val0 = (int)src[src_addr];
int val0 = (int)src[src_index];
for(int k = 0; k < maxk; k++ )
{
int val = (int)src[src_addr + space_ofs[k]];
float w = space_weight[k]*color_weight[abs(val - val0)];
sum += (float)(val)*w;
int val = (int)src[src_index + space_ofs[k]];
float w = space_weight[k] * color_weight[abs(val - val0)];
sum += (float)(val) * w;
wsum += w;
}
dst[dst_addr] = convert_uchar_rtz(sum/wsum+0.5f);
dst[dst_index] = convert_uchar_rtz(sum / wsum + 0.5f);
}
}
__kernel void bilateral2_C1_D0(__global uchar *dst,
__global const uchar *src,
const int dst_rows,
@ -81,25 +83,28 @@ __kernel void bilateral2_C1_D0(__global uchar *dst,
__constant float *space_weight,
__constant int *space_ofs)
{
int gidx = get_global_id(0)<<2;
int gidy = get_global_id(1);
if((gidy<dst_rows) && (gidx<dst_cols))
int x = get_global_id(0) << 2;
int y = get_global_id(1);
if (y < dst_rows && x < dst_cols)
{
int src_addr = mad24(gidy+radius,src_step,gidx+radius);
int dst_addr = mad24(gidy,dst_step,gidx+dst_offset);
int src_index = mad24(y + radius, src_step, x + radius);
int dst_index = mad24(y, dst_step, x + dst_offset);
float4 sum = (float4)(0.f), wsum = (float4)(0.f);
int4 val0 = convert_int4(vload4(0,src+src_addr));
int4 val0 = convert_int4(vload4(0,src + src_index));
for(int k = 0; k < maxk; k++ )
{
int4 val = convert_int4(vload4(0,src+src_addr + space_ofs[k]));
float4 w = (float4)(space_weight[k])*(float4)(color_weight[abs(val.x - val0.x)],color_weight[abs(val.y - val0.y)],color_weight[abs(val.z - val0.z)],color_weight[abs(val.w - val0.w)]);
sum += convert_float4(val)*w;
int4 val = convert_int4(vload4(0,src+src_index + space_ofs[k]));
float4 w = (float4)(space_weight[k]) * (float4)(color_weight[abs(val.x - val0.x)], color_weight[abs(val.y - val0.y)],
color_weight[abs(val.z - val0.z)], color_weight[abs(val.w - val0.w)]);
sum += convert_float4(val) * w;
wsum += w;
}
*(__global uchar4*)(dst+dst_addr) = convert_uchar4_rtz(sum/wsum+0.5f);
*(__global uchar4*)(dst+dst_index) = convert_uchar4_rtz(sum/wsum+0.5f);
}
}
__kernel void bilateral_C4_D0(__global uchar4 *dst,
__global const uchar4 *src,
const int dst_rows,
@ -115,24 +120,26 @@ __kernel void bilateral_C4_D0(__global uchar4 *dst,
__constant float *space_weight,
__constant int *space_ofs)
{
int gidx = get_global_id(0);
int gidy = get_global_id(1);
if((gidy<dst_rows) && (gidx<dst_cols))
int x = get_global_id(0);
int y = get_global_id(1);
if (y < dst_rows && x < dst_cols)
{
int src_addr = mad24(gidy+radius,src_step,gidx+radius);
int dst_addr = mad24(gidy,dst_step,gidx+dst_offset);
int src_index = mad24(y + radius, src_step, x + radius);
int dst_index = mad24(y, dst_step, x + dst_offset);
float4 sum = (float4)0.f;
float wsum = 0.f;
int4 val0 = convert_int4(src[src_addr]);
int4 val0 = convert_int4(src[src_index]);
for(int k = 0; k < maxk; k++ )
{
int4 val = convert_int4(src[src_addr + space_ofs[k]]);
float w = space_weight[k]*color_weight[abs(val.x - val0.x)+abs(val.y - val0.y)+abs(val.z - val0.z)];
sum += convert_float4(val)*(float4)w;
int4 val = convert_int4(src[src_index + space_ofs[k]]);
float w = space_weight[k] * color_weight[abs(val.x - val0.x) + abs(val.y - val0.y) + abs(val.z - val0.z)];
sum += convert_float4(val) * (float4)w;
wsum += w;
}
wsum=1.f/wsum;
dst[dst_addr] = convert_uchar4_rtz(sum*(float4)wsum+(float4)0.5f);
wsum = 1.f / wsum;
dst[dst_index] = convert_uchar4_rtz(sum * (float4)wsum + (float4)0.5f);
}
}

View File

@ -52,424 +52,394 @@
#ifdef HAVE_OPENCL
using namespace cvtest;
using namespace testing;
using namespace std;
using namespace cv;
PARAM_TEST_CASE(FilterTestBase,
MatType,
cv::Size, // kernel size
cv::Size, // dx,dy
int // border type, or iteration
)
PARAM_TEST_CASE(FilterTestBase, MatType,
int, // kernel size
Size, // dx, dy
int, // border type, or iteration
bool) // roi or not
{
//src mat
cv::Mat mat1;
cv::Mat dst;
int type, borderType;
int ksize;
bool useRoi;
// set up roi
int roicols;
int roirows;
int src1x;
int src1y;
int dstx;
int dsty;
Mat src, dst_whole, src_roi, dst_roi;
ocl::oclMat gsrc_whole, gsrc_roi, gdst_whole, gdst_roi;
//src mat with roi
cv::Mat mat1_roi;
cv::Mat dst_roi;
//ocl dst mat for testing
cv::ocl::oclMat gdst_whole;
//ocl mat with roi
cv::ocl::oclMat gmat1;
cv::ocl::oclMat gdst;
virtual void SetUp()
{
type = GET_PARAM(0);
ksize = GET_PARAM(1);
borderType = GET_PARAM(3);
useRoi = GET_PARAM(4);
}
void random_roi()
{
#ifdef RANDOMROI
//randomize ROI
roicols = rng.uniform(2, mat1.cols);
roirows = rng.uniform(2, mat1.rows);
src1x = rng.uniform(0, mat1.cols - roicols);
src1y = rng.uniform(0, mat1.rows - roirows);
dstx = rng.uniform(0, dst.cols - roicols);
dsty = rng.uniform(0, dst.rows - roirows);
#else
roicols = mat1.cols;
roirows = mat1.rows;
src1x = 0;
src1y = 0;
dstx = 0;
dsty = 0;
#endif
Size roiSize = randomSize(1, MAX_VALUE);
Border srcBorder = randomBorder(0, useRoi ? MAX_VALUE : 0);
randomSubMat(src, src_roi, roiSize, srcBorder, type, 5, 256);
mat1_roi = mat1(Rect(src1x, src1y, roicols, roirows));
dst_roi = dst(Rect(dstx, dsty, roicols, roirows));
Border dstBorder = randomBorder(0, useRoi ? MAX_VALUE : 0);
randomSubMat(dst_whole, dst_roi, roiSize, dstBorder, type, 5, 16);
gdst_whole = dst;
gdst = gdst_whole(Rect(dstx, dsty, roicols, roirows));
gmat1 = mat1_roi;
generateOclMat(gsrc_whole, gsrc_roi, src, roiSize, srcBorder);
generateOclMat(gdst_whole, gdst_roi, dst_whole, roiSize, dstBorder);
}
void Init(int mat_type)
void Near(double threshold = 0.0)
{
cv::Size size(MWIDTH, MHEIGHT);
mat1 = randomMat(size, mat_type, 5, 16);
dst = randomMat(size, mat_type, 5, 16);
}
void Near(double threshold)
{
EXPECT_MAT_NEAR(dst, Mat(gdst_whole), threshold);
EXPECT_MAT_NEAR(dst_whole, Mat(gdst_whole), threshold);
EXPECT_MAT_NEAR(dst_roi, Mat(gdst_roi), threshold);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// blur
struct Blur : FilterTestBase
{
int type;
cv::Size ksize;
int bordertype;
virtual void SetUp()
{
type = GET_PARAM(0);
ksize = GET_PARAM(1);
bordertype = GET_PARAM(3);
Init(type);
}
};
typedef FilterTestBase Blur;
OCL_TEST_P(Blur, Mat)
{
for(int j = 0; j < LOOP_TIMES; j++)
Size kernelSize(ksize, ksize);
for (int j = 0; j < LOOP_TIMES; j++)
{
random_roi();
cv::blur(mat1_roi, dst_roi, ksize, Point(-1, -1), bordertype);
cv::ocl::blur(gmat1, gdst, ksize, Point(-1, -1), bordertype);
blur(src_roi, dst_roi, kernelSize, Point(-1, -1), borderType);
ocl::blur(gsrc_roi, gdst_roi, kernelSize, Point(-1, -1), borderType); // TODO anchor
Near(1.0);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
//Laplacian
struct Laplacian : FilterTestBase
{
int type;
cv::Size ksize;
// Laplacian
virtual void SetUp()
{
type = GET_PARAM(0);
ksize = GET_PARAM(1);
Init(type);
}
};
typedef FilterTestBase LaplacianTest;
OCL_TEST_P(Laplacian, Accuracy)
OCL_TEST_P(LaplacianTest, Accuracy)
{
for(int j = 0; j < LOOP_TIMES; j++)
for (int j = 0; j < LOOP_TIMES; j++)
{
random_roi();
cv::Laplacian(mat1_roi, dst_roi, -1, ksize.width, 1);
cv::ocl::Laplacian(gmat1, gdst, -1, ksize.width, 1);
Laplacian(src_roi, dst_roi, -1, ksize, 1);
ocl::Laplacian(gsrc_roi, gdst_roi, -1, ksize, 1); // TODO scale
Near(1e-5);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// erode & dilate
struct ErodeDilate : FilterTestBase
{
int type;
int iterations;
//erode or dilate kernel
cv::Mat kernel;
struct ErodeDilate :
public FilterTestBase
{
int iterations;
virtual void SetUp()
{
type = GET_PARAM(0);
iterations = GET_PARAM(3);
Init(type);
kernel = randomMat(Size(3, 3), CV_8UC1, 0, 3);
useRoi = GET_PARAM(4);
}
};
OCL_TEST_P(ErodeDilate, Mat)
typedef ErodeDilate Erode;
OCL_TEST_P(Erode, Mat)
{
for(int j = 0; j < LOOP_TIMES; j++)
// erode or dilate kernel
Size kernelSize(ksize, ksize);
Mat kernel;
for (int j = 0; j < LOOP_TIMES; j++)
{
kernel = randomMat(kernelSize, CV_8UC1, 0, 3);
random_roi();
cv::erode(mat1_roi, dst_roi, kernel, Point(-1, -1), iterations);
cv::ocl::erode(gmat1, gdst, kernel, Point(-1, -1), iterations);
Near(1e-5);
}
for(int j = 0; j < LOOP_TIMES; j++)
{
random_roi();
cv::dilate(mat1_roi, dst_roi, kernel, Point(-1, -1), iterations);
cv::ocl::dilate(gmat1, gdst, kernel, Point(-1, -1), iterations);
cv::erode(src_roi, dst_roi, kernel, Point(-1, -1), iterations);
ocl::erode(gsrc_roi, gdst_roi, kernel, Point(-1, -1), iterations); // TODO iterations, borderType
Near(1e-5);
}
}
typedef ErodeDilate Dilate;
OCL_TEST_P(Dilate, Mat)
{
// erode or dilate kernel
Mat kernel;
for (int j = 0; j < LOOP_TIMES; j++)
{
kernel = randomMat(Size(3, 3), CV_8UC1, 0, 3);
random_roi();
cv::dilate(src_roi, dst_roi, kernel, Point(-1, -1), iterations);
ocl::dilate(gsrc_roi, gdst_roi, kernel, Point(-1, -1), iterations); // TODO iterations, borderType
Near(1e-5);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Sobel
struct Sobel : FilterTestBase
struct SobelTest :
public FilterTestBase
{
int type;
int dx, dy, ksize, bordertype;
int dx, dy;
virtual void SetUp()
{
type = GET_PARAM(0);
Size s = GET_PARAM(1);
ksize = s.width;
s = GET_PARAM(2);
dx = s.width;
dy = s.height;
bordertype = GET_PARAM(3);
Init(type);
ksize = GET_PARAM(1);
borderType = GET_PARAM(3);
useRoi = GET_PARAM(4);
Size d = GET_PARAM(2);
dx = d.width, dy = d.height;
}
};
OCL_TEST_P(Sobel, Mat)
OCL_TEST_P(SobelTest, Mat)
{
for(int j = 0; j < LOOP_TIMES; j++)
for (int j = 0; j < LOOP_TIMES; j++)
{
random_roi();
cv::Sobel(mat1_roi, dst_roi, -1, dx, dy, ksize, /*scale*/0.00001,/*delta*/0, bordertype);
cv::ocl::Sobel(gmat1, gdst, -1, dx, dy, ksize,/*scale*/0.00001,/*delta*/0, bordertype);
Sobel(src_roi, dst_roi, -1, dx, dy, ksize, /* scale */ 0.00001, /* delta */0, borderType);
ocl::Sobel(gsrc_roi, gdst_roi, -1, dx, dy, ksize, /* scale */ 0.00001, /* delta */ 0, borderType);
Near(1);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Scharr
struct Scharr : FilterTestBase
{
int type;
int dx, dy, bordertype;
virtual void SetUp()
{
type = GET_PARAM(0);
Size s = GET_PARAM(2);
dx = s.width;
dy = s.height;
bordertype = GET_PARAM(3);
Init(type);
}
};
typedef SobelTest ScharrTest;
OCL_TEST_P(Scharr, Mat)
OCL_TEST_P(ScharrTest, Mat)
{
for(int j = 0; j < LOOP_TIMES; j++)
for (int j = 0; j < LOOP_TIMES; j++)
{
random_roi();
cv::Scharr(mat1_roi, dst_roi, -1, dx, dy, /*scale*/1,/*delta*/0, bordertype);
cv::ocl::Scharr(gmat1, gdst, -1, dx, dy,/*scale*/1,/*delta*/0, bordertype);
Scharr(src_roi, dst_roi, -1, dx, dy, /* scale */ 1, /* delta */ 0, borderType);
ocl::Scharr(gsrc_roi, gdst_roi, -1, dx, dy, /* scale */ 1, /* delta */ 0, borderType);
Near(1);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// GaussianBlur
struct GaussianBlur : FilterTestBase
struct GaussianBlurTest :
public FilterTestBase
{
int type;
cv::Size ksize;
int bordertype;
double sigma1, sigma2;
virtual void SetUp()
{
type = GET_PARAM(0);
ksize = GET_PARAM(1);
bordertype = GET_PARAM(3);
Init(type);
borderType = GET_PARAM(3);
sigma1 = rng.uniform(0.1, 1.0);
sigma2 = rng.uniform(0.1, 1.0);
}
};
OCL_TEST_P(GaussianBlur, Mat)
OCL_TEST_P(GaussianBlurTest, Mat)
{
for(int j = 0; j < LOOP_TIMES; j++)
for (int j = 0; j < LOOP_TIMES; j++)
{
random_roi();
cv::GaussianBlur(mat1_roi, dst_roi, ksize, sigma1, sigma2, bordertype);
cv::ocl::GaussianBlur(gmat1, gdst, ksize, sigma1, sigma2, bordertype);
GaussianBlur(src_roi, dst_roi, Size(ksize, ksize), sigma1, sigma2, borderType);
ocl::GaussianBlur(gsrc_roi, gdst_roi, Size(ksize, ksize), sigma1, sigma2, borderType);
Near(1);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Filter2D
struct Filter2D : FilterTestBase
{
int type;
cv::Size ksize;
int bordertype;
Point anchor;
virtual void SetUp()
{
type = GET_PARAM(0);
ksize = GET_PARAM(1);
bordertype = GET_PARAM(3);
Init(type);
anchor = Point(-1,-1);
}
};
typedef FilterTestBase Filter2D;
OCL_TEST_P(Filter2D, Mat)
{
cv::Mat kernel = randomMat(cv::Size(ksize.width, ksize.height), CV_32FC1, 0.0, 1.0);
for(int j = 0; j < LOOP_TIMES; j++)
const Size kernelSize(ksize, ksize);
Mat kernel;
for (int j = 0; j < LOOP_TIMES; j++)
{
kernel = randomMat(kernelSize, CV_32FC1, 0.0, 1.0);
random_roi();
cv::filter2D(mat1_roi, dst_roi, -1, kernel, anchor, 0.0, bordertype);
cv::ocl::filter2D(gmat1, gdst, -1, kernel, anchor, bordertype);
cv::filter2D(src_roi, dst_roi, -1, kernel, Point(-1, -1), 0.0, borderType); // TODO anchor
ocl::filter2D(gsrc_roi, gdst_roi, -1, kernel, Point(-1, -1), borderType);
Near(1);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Bilateral
struct Bilateral : FilterTestBase
{
int type;
cv::Size ksize;
int bordertype;
double sigmacolor, sigmaspace;
virtual void SetUp()
{
type = GET_PARAM(0);
ksize = GET_PARAM(1);
bordertype = GET_PARAM(3);
Init(type);
sigmacolor = rng.uniform(20, 100);
sigmaspace = rng.uniform(10, 40);
}
};
typedef FilterTestBase Bilateral;
OCL_TEST_P(Bilateral, Mat)
{
for(int j = 0; j < LOOP_TIMES; j++)
for (int j = 0; j < LOOP_TIMES; j++)
{
random_roi();
cv::bilateralFilter(mat1_roi, dst_roi, ksize.width, sigmacolor, sigmaspace, bordertype);
cv::ocl::bilateralFilter(gmat1, gdst, ksize.width, sigmacolor, sigmaspace, bordertype);
double sigmacolor = rng.uniform(20, 100);
double sigmaspace = rng.uniform(10, 40);
cv::bilateralFilter(src_roi, dst_roi, ksize, sigmacolor, sigmaspace, borderType);
ocl::bilateralFilter(gsrc_roi, gdst_roi, ksize, sigmacolor, sigmaspace, borderType);
Near(1);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// AdaptiveBilateral
struct AdaptiveBilateral : FilterTestBase
{
int type;
cv::Size ksize;
int bordertype;
Point anchor;
virtual void SetUp()
{
type = GET_PARAM(0);
ksize = GET_PARAM(1);
bordertype = GET_PARAM(3);
Init(type);
anchor = Point(-1,-1);
}
};
typedef FilterTestBase AdaptiveBilateral;
OCL_TEST_P(AdaptiveBilateral, Mat)
{
for(int j = 0; j < LOOP_TIMES; j++)
const Size kernelSize(ksize, ksize);
for (int j = 0; j < LOOP_TIMES; j++)
{
random_roi();
cv::adaptiveBilateralFilter(mat1_roi, dst_roi, ksize, 5, anchor, bordertype);
cv::ocl::adaptiveBilateralFilter(gmat1, gdst, ksize, 5, anchor, bordertype);
adaptiveBilateralFilter(src_roi, dst_roi, kernelSize, 5, Point(-1, -1), borderType); // TODO anchor
ocl::adaptiveBilateralFilter(gsrc_roi, gdst_roi, kernelSize, 5, Point(-1, -1), borderType);
Near(1);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////////
// MedianFilter
typedef FilterTestBase MedianFilter;
OCL_TEST_P(MedianFilter, Mat)
{
for (int i = 0; i < LOOP_TIMES; ++i)
{
random_roi();
medianBlur(src_roi, dst_roi, ksize);
ocl::medianFilter(gsrc_roi, gdst_roi, ksize);
Near();
}
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////
INSTANTIATE_TEST_CASE_P(Filter, Blur, Combine(
Values(CV_8UC1, CV_8UC3, CV_8UC4, CV_32FC1, CV_32FC4),
Values(cv::Size(3, 3), cv::Size(5, 5), cv::Size(7, 7)),
Values(Size(0, 0)), //not use
Values((MatType)cv::BORDER_CONSTANT, (MatType)cv::BORDER_REPLICATE, (MatType)cv::BORDER_REFLECT, (MatType)cv::BORDER_REFLECT_101)));
Values(CV_8UC1, CV_8UC3, CV_8UC4, CV_32FC1, CV_32FC4),
Values(3, 5, 7),
Values(Size(0, 0)), // not used
Values((int)BORDER_CONSTANT, (int)BORDER_REPLICATE, (int)BORDER_REFLECT, (int)BORDER_REFLECT_101),
Bool()));
INSTANTIATE_TEST_CASE_P(Filter, LaplacianTest, Combine(
Values(CV_8UC1, CV_8UC3, CV_8UC4, CV_32FC1, CV_32FC3, CV_32FC4),
Values(1, 3),
Values(Size(0, 0)), // not used
Values(0), // not used
Bool()));
INSTANTIATE_TEST_CASE_P(Filter, Laplacian, Combine(
Values(CV_8UC1, CV_8UC3, CV_8UC4, CV_32FC1, CV_32FC3, CV_32FC4),
Values(Size(3, 3)),
Values(Size(0, 0)), //not use
Values(0))); //not use
INSTANTIATE_TEST_CASE_P(Filter, Erode, Combine(
Values(CV_8UC1, CV_8UC3, CV_8UC4, CV_32FC1, CV_32FC3, CV_32FC4),
Values(3, 5, 7),
Values(Size(0, 0)), // not used
testing::Range(1, 2),
Bool()));
INSTANTIATE_TEST_CASE_P(Filter, ErodeDilate, Combine(
Values(CV_8UC1, CV_8UC4, CV_32FC1, CV_32FC4),
Values(Size(0, 0)), //not use
Values(Size(0, 0)), //not use
Values(1)));
INSTANTIATE_TEST_CASE_P(Filter, Dilate, Combine(
Values(CV_8UC1, CV_8UC3, CV_8UC4, CV_32FC1, CV_32FC3, CV_32FC4),
Values(3, 5, 7),
Values(Size(0, 0)), // not used
testing::Range(1, 2),
Bool()));
INSTANTIATE_TEST_CASE_P(Filter, SobelTest, Combine(
Values(CV_8UC1, CV_8UC3, CV_8UC4, CV_32FC1, CV_32FC3, CV_32FC4),
Values(3, 5),
Values(Size(1, 0), Size(1, 1), Size(2, 0), Size(2, 1)),
Values((int)BORDER_CONSTANT, (int)BORDER_REFLECT101,
(int)BORDER_REPLICATE, (int)BORDER_REFLECT),
Bool()));
INSTANTIATE_TEST_CASE_P(Filter, Sobel, Combine(
Values(CV_8UC1, CV_8UC3, CV_8UC4, CV_32FC1, CV_32FC3, CV_32FC4),
Values(Size(3, 3), Size(5, 5)),
Values(Size(1, 0), Size(1, 1), Size(2, 0), Size(2, 1)),
Values((MatType)cv::BORDER_CONSTANT, (MatType)cv::BORDER_REPLICATE)));
INSTANTIATE_TEST_CASE_P(Filter, Scharr, Combine(
Values(CV_8UC1, CV_8UC3, CV_8UC4, CV_32FC1, CV_32FC4),
Values(Size(0, 0)), //not use
Values(Size(0, 1), Size(1, 0)),
Values((MatType)cv::BORDER_CONSTANT, (MatType)cv::BORDER_REPLICATE)));
INSTANTIATE_TEST_CASE_P(Filter, GaussianBlur, Combine(
Values(CV_8UC1, CV_8UC3, CV_8UC4, CV_32FC1, CV_32FC4),
Values(Size(3, 3), Size(5, 5)),
Values(Size(0, 0)), //not use
Values((MatType)cv::BORDER_CONSTANT, (MatType)cv::BORDER_REPLICATE)));
INSTANTIATE_TEST_CASE_P(Filter, ScharrTest, Combine(
Values(CV_8UC1, CV_8UC3, CV_8UC4, CV_32FC1, CV_32FC4),
Values(0), // not used
Values(Size(0, 1), Size(1, 0)),
Values((int)BORDER_CONSTANT, (int)BORDER_REFLECT101,
(int)BORDER_REPLICATE, (int)BORDER_REFLECT),
Bool()));
INSTANTIATE_TEST_CASE_P(Filter, GaussianBlurTest, Combine(
Values(CV_8UC1, CV_8UC3, CV_8UC4, CV_32FC1, CV_32FC4),
Values(3, 5),
Values(Size(0, 0)), // not used
Values((int)BORDER_CONSTANT, (int)BORDER_REFLECT101,
(int)BORDER_REPLICATE, (int)BORDER_REFLECT),
Bool()));
INSTANTIATE_TEST_CASE_P(Filter, Filter2D, testing::Combine(
Values(CV_8UC1, CV_32FC1, CV_32FC4),
Values(Size(3, 3), Size(15, 15), Size(25, 25)),
Values(Size(0, 0)), //not use
Values((MatType)cv::BORDER_CONSTANT, (MatType)cv::BORDER_REFLECT101, (MatType)cv::BORDER_REPLICATE, (MatType)cv::BORDER_REFLECT)));
Values(CV_8UC1, CV_32FC1, CV_32FC4),
Values(3, 15, 25),
Values(Size(0, 0)), // not used
Values((int)BORDER_CONSTANT, (int)BORDER_REFLECT101,
(int)BORDER_REPLICATE, (int)BORDER_REFLECT),
Bool()));
INSTANTIATE_TEST_CASE_P(Filter, Bilateral, Combine(
Values(CV_8UC1, CV_8UC3),
Values(Size(5, 5), Size(9, 9)),
Values(Size(0, 0)), //not use
Values((MatType)cv::BORDER_CONSTANT, (MatType)cv::BORDER_REPLICATE,
(MatType)cv::BORDER_REFLECT, (MatType)cv::BORDER_WRAP, (MatType)cv::BORDER_REFLECT_101)));
Values(CV_8UC1, CV_8UC3),
Values(5, 9),
Values(Size(0, 0)), // not used
Values((int)BORDER_CONSTANT, (int)BORDER_REPLICATE,
(int)BORDER_REFLECT, (int)BORDER_WRAP, (int)BORDER_REFLECT_101),
Values(false))); // TODO does not work with ROI
INSTANTIATE_TEST_CASE_P(Filter, AdaptiveBilateral, Combine(
Values(CV_8UC1, CV_8UC3),
Values(Size(5, 5), Size(9, 9)),
Values(Size(0, 0)), //not use
Values((MatType)cv::BORDER_CONSTANT, (MatType)cv::BORDER_REPLICATE,
(MatType)cv::BORDER_REFLECT, (MatType)cv::BORDER_REFLECT_101)));
Values(CV_8UC1, CV_8UC3),
Values(5, 9),
Values(Size(0, 0)), // not used
Values((int)BORDER_CONSTANT, (int)BORDER_REPLICATE,
(int)BORDER_REFLECT, (int)BORDER_REFLECT_101),
Bool()));
INSTANTIATE_TEST_CASE_P(Filter, MedianFilter, Combine(
Values((MatType)CV_8UC1, (MatType)CV_8UC4, (MatType)CV_32FC1, (MatType)CV_32FC4),
Values(3, 5),
Values(Size(0, 0)), // not used
Values(0), // not used
Bool()));
#endif // HAVE_OPENCL

File diff suppressed because it is too large Load Diff