new implementation of gpu debayer

* without border extrapolation
* with aligned write
* process 4 pixels per thread in 8u case
This commit is contained in:
Vladislav Vinogradov 2012-08-06 10:58:11 +04:00
parent 5ce896d9ee
commit 8624d18ca5
4 changed files with 273 additions and 144 deletions

View File

@ -1148,6 +1148,9 @@ GPU_PERF_TEST(CvtColor, cv::gpu::DeviceInfo, cv::Size, MatDepth, CvtColorInfo)
cv::gpu::GpuMat src(src_host);
cv::gpu::GpuMat dst;
if (info.code >= cv::COLOR_BayerBG2BGR && info.code <= cv::COLOR_BayerGR2BGR)
info.dcn = 4;
cv::gpu::cvtColor(src, dst, info.code, info.dcn);
TEST_CYCLE()

View File

@ -58,8 +58,10 @@ void cv::gpu::gammaCorrection(const GpuMat&, GpuMat&, bool, Stream&) { throw_nog
namespace cv { namespace gpu {
namespace device
{
template <typename T, int cn>
void Bayer2BGR_gpu(DevMem2Db src, DevMem2Db dst, bool blue_last, bool start_with_green, cudaStream_t stream);
template <int cn>
void Bayer2BGR_8u_gpu(DevMem2Db src, DevMem2Db dst, bool blue_last, bool start_with_green, cudaStream_t stream);
template <int cn>
void Bayer2BGR_16u_gpu(DevMem2Db src, DevMem2Db dst, bool blue_last, bool start_with_green, cudaStream_t stream);
}
}}
@ -1337,9 +1339,9 @@ namespace
typedef void (*func_t)(DevMem2Db src, DevMem2Db dst, bool blue_last, bool start_with_green, cudaStream_t stream);
static const func_t funcs[3][4] =
{
{0,0,Bayer2BGR_gpu<uchar, 3>, Bayer2BGR_gpu<uchar, 4>},
{0,0,Bayer2BGR_8u_gpu<3>, Bayer2BGR_8u_gpu<4>},
{0,0,0,0},
{0,0,Bayer2BGR_gpu<ushort, 3>, Bayer2BGR_gpu<ushort, 4>}
{0,0,Bayer2BGR_16u_gpu<3>, Bayer2BGR_16u_gpu<4>}
};
if (dcn <= 0) dcn = 3;

View File

@ -42,167 +42,286 @@
#include <opencv2/gpu/device/common.hpp>
#include <opencv2/gpu/device/vec_traits.hpp>
#include <opencv2/gpu/device/vec_math.hpp>
#include <opencv2/gpu/device/limits.hpp>
namespace cv { namespace gpu {
namespace device
{
template <class SrcPtr, typename T>
__global__ void Bayer2BGR(const SrcPtr src, PtrStep_<T> dst, const int width, const int height, const bool glob_blue_last, const bool glob_start_with_green)
template <typename D>
__global__ void Bayer2BGR_8u(const PtrStepb src, DevMem2D_<D> dst, const bool blue_last, const bool start_with_green)
{
const int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
const int s_x = blockIdx.x * blockDim.x + threadIdx.x;
int s_y = blockIdx.y * blockDim.y + threadIdx.y;
if (y >= height)
if (s_y >= dst.rows || (s_x << 2) >= dst.cols)
return;
const bool blue_last = (y & 1) ? !glob_blue_last : glob_blue_last;
const bool start_with_green = (y & 1) ? !glob_start_with_green : glob_start_with_green;
s_y = ::min(::max(s_y, 1), dst.rows - 2);
int x = tx * 2;
uchar4 patch[3][3];
patch[0][1] = ((const uchar4*) src.ptr(s_y - 1))[s_x];
patch[0][0] = ((const uchar4*) src.ptr(s_y - 1))[::max(s_x - 1, 0)];
patch[0][2] = ((const uchar4*) src.ptr(s_y - 1))[::min(s_x + 1, ((dst.cols + 3) >> 2) - 1)];
if (start_with_green)
patch[1][1] = ((const uchar4*) src.ptr(s_y))[s_x];
patch[1][0] = ((const uchar4*) src.ptr(s_y))[::max(s_x - 1, 0)];
patch[1][2] = ((const uchar4*) src.ptr(s_y))[::min(s_x + 1, ((dst.cols + 3) >> 2) - 1)];
patch[2][1] = ((const uchar4*) src.ptr(s_y + 1))[s_x];
patch[2][0] = ((const uchar4*) src.ptr(s_y + 1))[::max(s_x - 1, 0)];
patch[2][2] = ((const uchar4*) src.ptr(s_y + 1))[::min(s_x + 1, ((dst.cols + 3) >> 2) - 1)];
D res0 = VecTraits<D>::all(numeric_limits<uchar>::max());
D res1 = VecTraits<D>::all(numeric_limits<uchar>::max());
D res2 = VecTraits<D>::all(numeric_limits<uchar>::max());
D res3 = VecTraits<D>::all(numeric_limits<uchar>::max());
if ((s_y & 1) ^ start_with_green)
{
--x;
const int t0 = (patch[0][1].x + patch[2][1].x + 1) >> 1;
const int t1 = (patch[1][0].w + patch[1][1].y + 1) >> 1;
if (tx == 0)
const int t2 = (patch[0][1].x + patch[0][1].z + patch[2][1].x + patch[2][1].z + 2) >> 2;
const int t3 = (patch[0][1].y + patch[1][1].x + patch[1][1].z + patch[2][1].y + 2) >> 2;
const int t4 = (patch[0][1].z + patch[2][1].z + 1) >> 1;
const int t5 = (patch[1][1].y + patch[1][1].w + 1) >> 1;
const int t6 = (patch[0][1].z + patch[0][2].x + patch[2][1].z + patch[2][2].x + 2) >> 2;
const int t7 = (patch[0][1].w + patch[1][1].z + patch[1][2].x + patch[2][1].w + 2) >> 2;
if ((s_y & 1) ^ blue_last)
{
const int t0 = (src(y, 1) + src(y + 2, 1) + 1) >> 1;
const int t1 = (src(y + 1, 0) + src(y + 1, 2) + 1) >> 1;
res0.x = t1;
res0.y = patch[1][1].x;
res0.z = t0;
T res;
res.x = blue_last ? t0 : t1;
res.y = src(y + 1, 1);
res.z = blue_last ? t1 : t0;
res1.x = patch[1][1].y;
res1.y = t3;
res1.z = t2;
dst(y + 1, 0) = dst(y + 1, 1) = res;
if (y == 0)
{
dst(0, 0) = dst(0, 1) = res;
}
else if (y == height - 1)
{
dst(height + 1, 0) = dst(height + 1, 1) = res;
}
}
}
res2.x = t5;
res2.y = patch[1][1].z;
res2.z = t4;
if (x >= 0 && x <= width - 2)
{
const int t0 = (src(y, x) + src(y, x + 2) + src(y + 2, x) + src(y + 2, x + 2) + 2) >> 2;
const int t1 = (src(y, x + 1) + src(y + 1, x) + src(y + 1, x + 2) + src(y + 2, x + 1) + 2) >> 2;
const int t2 = (src(y, x + 2) + src(y + 2, x + 2) + 1) >> 1;
const int t3 = (src(y + 1, x + 1) + src(y + 1, x + 3) + 1) >> 1;
T res1, res2;
if (blue_last)
{
res1.x = t0;
res1.y = t1;
res1.z = src(y + 1, x + 1);
res2.x = t2;
res2.y = src(y + 1, x + 2);
res2.z = t3;
res3.x = patch[1][1].w;
res3.y = t7;
res3.z = t6;
}
else
{
res1.x = src(y + 1, x + 1);
res1.y = t1;
res1.z = t0;
res0.x = t0;
res0.y = patch[1][1].x;
res0.z = t1;
res2.x = t3;
res2.y = src(y + 1, x + 2);
res2.z = t2;
}
res1.x = t2;
res1.y = t3;
res1.z = patch[1][1].y;
dst(y + 1, x + 1) = res1;
dst(y + 1, x + 2) = res2;
res2.x = t4;
res2.y = patch[1][1].z;
res2.z = t5;
if (y == 0)
{
dst(0, x + 1) = res1;
dst(0, x + 2) = res2;
if (x == 0)
{
dst(0, 0) = res1;
}
else if (x == width - 2)
{
dst(0, width + 1) = res2;
}
}
else if (y == height - 1)
{
dst(height + 1, x + 1) = res1;
dst(height + 1, x + 2) = res2;
if (x == 0)
{
dst(height + 1, 0) = res1;
}
else if (x == width - 2)
{
dst(height + 1, width + 1) = res2;
}
}
if (x == 0)
{
dst(y + 1, 0) = res1;
}
else if (x == width - 2)
{
dst(y + 1, width + 1) = res2;
res3.x = t6;
res3.y = t7;
res3.z = patch[1][1].w;
}
}
else if (x == width - 1)
else
{
const int t0 = (src(y, x) + src(y, x + 2) + src(y + 2, x) + src(y + 2, x + 2) + 2) >> 2;
const int t1 = (src(y, x + 1) + src(y + 1, x) + src(y + 1, x + 2) + src(y + 2, x + 1) + 2) >> 2;
const int t0 = (patch[0][0].w + patch[0][1].y + patch[2][0].w + patch[2][1].y + 2) >> 2;
const int t1 = (patch[0][1].x + patch[1][0].w + patch[1][1].y + patch[2][1].x + 2) >> 2;
T res;
res.x = blue_last ? t0 : src(y + 1, x + 1);
res.y = t1;
res.z = blue_last ? src(y + 1, x + 1) : t0;
const int t2 = (patch[0][1].y + patch[2][1].y + 1) >> 1;
const int t3 = (patch[1][1].x + patch[1][1].z + 1) >> 1;
dst(y + 1, x + 1) = dst(y + 1, x + 2) = res;
if (y == 0)
const int t4 = (patch[0][1].y + patch[0][1].w + patch[2][1].y + patch[2][1].w + 2) >> 2;
const int t5 = (patch[0][1].z + patch[1][1].y + patch[1][1].w + patch[2][1].z + 2) >> 2;
const int t6 = (patch[0][1].w + patch[2][1].w + 1) >> 1;
const int t7 = (patch[1][1].z + patch[1][2].x + 1) >> 1;
if ((s_y & 1) ^ blue_last)
{
dst(0, x + 1) = dst(0, x + 2) = res;
res0.x = patch[1][1].x;
res0.y = t1;
res0.z = t0;
res1.x = t3;
res1.y = patch[1][1].y;
res1.z = t2;
res2.x = patch[1][1].z;
res2.y = t5;
res2.z = t4;
res3.x = t7;
res3.y = patch[1][1].w;
res3.z = t6;
}
else if (y == height - 1)
else
{
dst(height + 1, x + 1) = dst(height + 1, x + 2) = res;
res0.x = t0;
res0.y = t1;
res0.z = patch[1][1].x;
res1.x = t2;
res1.y = patch[1][1].y;
res1.z = t3;
res2.x = t4;
res2.y = t5;
res2.z = patch[1][1].z;
res3.x = t6;
res3.y = patch[1][1].w;
res3.z = t7;
}
}
const int d_x = (blockIdx.x * blockDim.x + threadIdx.x) << 2;
const int d_y = blockIdx.y * blockDim.y + threadIdx.y;
dst(d_y, d_x) = res0;
if (d_x + 1 < dst.cols)
dst(d_y, d_x + 1) = res1;
if (d_x + 2 < dst.cols)
dst(d_y, d_x + 2) = res2;
if (d_x + 3 < dst.cols)
dst(d_y, d_x + 3) = res3;
}
template <typename T, int cn>
void Bayer2BGR_gpu(DevMem2Db src, DevMem2Db dst, bool blue_last, bool start_with_green, cudaStream_t stream)
template <typename D>
__global__ void Bayer2BGR_16u(const PtrStepb src, DevMem2D_<D> dst, const bool blue_last, const bool start_with_green)
{
typedef typename TypeVec<T, cn>::vec_type dst_t;
const int s_x = blockIdx.x * blockDim.x + threadIdx.x;
int s_y = blockIdx.y * blockDim.y + threadIdx.y;
const int width = src.cols - 2;
const int height = src.rows - 2;
if (s_y >= dst.rows || (s_x << 1) >= dst.cols)
return;
const dim3 total(divUp(width, 2), height);
s_y = ::min(::max(s_y, 1), dst.rows - 2);
ushort2 patch[3][3];
patch[0][1] = ((const ushort2*) src.ptr(s_y - 1))[s_x];
patch[0][0] = ((const ushort2*) src.ptr(s_y - 1))[::max(s_x - 1, 0)];
patch[0][2] = ((const ushort2*) src.ptr(s_y - 1))[::min(s_x + 1, ((dst.cols + 1) >> 1) - 1)];
patch[1][1] = ((const ushort2*) src.ptr(s_y))[s_x];
patch[1][0] = ((const ushort2*) src.ptr(s_y))[::max(s_x - 1, 0)];
patch[1][2] = ((const ushort2*) src.ptr(s_y))[::min(s_x + 1, ((dst.cols + 1) >> 1) - 1)];
patch[2][1] = ((const ushort2*) src.ptr(s_y + 1))[s_x];
patch[2][0] = ((const ushort2*) src.ptr(s_y + 1))[::max(s_x - 1, 0)];
patch[2][2] = ((const ushort2*) src.ptr(s_y + 1))[::min(s_x + 1, ((dst.cols + 1) >> 1) - 1)];
D res0 = VecTraits<D>::all(numeric_limits<ushort>::max());
D res1 = VecTraits<D>::all(numeric_limits<ushort>::max());
if ((s_y & 1) ^ start_with_green)
{
const int t0 = (patch[0][1].x + patch[2][1].x + 1) >> 1;
const int t1 = (patch[1][0].y + patch[1][1].y + 1) >> 1;
const int t2 = (patch[0][1].x + patch[0][2].x + patch[2][1].x + patch[2][2].x + 2) >> 2;
const int t3 = (patch[0][1].y + patch[1][1].x + patch[1][2].x + patch[2][1].y + 2) >> 2;
if ((s_y & 1) ^ blue_last)
{
res0.x = t1;
res0.y = patch[1][1].x;
res0.z = t0;
res1.x = patch[1][1].y;
res1.y = t3;
res1.z = t2;
}
else
{
res0.x = t0;
res0.y = patch[1][1].x;
res0.z = t1;
res1.x = t2;
res1.y = t3;
res1.z = patch[1][1].y;
}
}
else
{
const int t0 = (patch[0][0].y + patch[0][1].y + patch[2][0].y + patch[2][1].y + 2) >> 2;
const int t1 = (patch[0][1].x + patch[1][0].y + patch[1][1].y + patch[2][1].x + 2) >> 2;
const int t2 = (patch[0][1].y + patch[2][1].y + 1) >> 1;
const int t3 = (patch[1][1].x + patch[1][2].x + 1) >> 1;
if ((s_y & 1) ^ blue_last)
{
res0.x = patch[1][1].x;
res0.y = t1;
res0.z = t0;
res1.x = t3;
res1.y = patch[1][1].y;
res1.z = t2;
}
else
{
res0.x = t0;
res0.y = t1;
res0.z = patch[1][1].x;
res1.x = t2;
res1.y = patch[1][1].y;
res1.z = t3;
}
}
const int d_x = (blockIdx.x * blockDim.x + threadIdx.x) << 1;
const int d_y = blockIdx.y * blockDim.y + threadIdx.y;
dst(d_y, d_x) = res0;
if (d_x + 1 < dst.cols)
dst(d_y, d_x + 1) = res1;
}
template <int cn>
void Bayer2BGR_8u_gpu(DevMem2Db src, DevMem2Db dst, bool blue_last, bool start_with_green, cudaStream_t stream)
{
typedef typename TypeVec<uchar, cn>::vec_type dst_t;
const dim3 block(32, 8);
const dim3 grid(divUp(total.x, block.x), divUp(total.y, block.y));
const dim3 grid(divUp(dst.cols, 4 * block.x), divUp(dst.rows, block.y));
Bayer2BGR<PtrStep_<T>, dst_t><<<grid, block, 0, stream>>>((DevMem2D_<T>)src, (DevMem2D_<dst_t>)dst, width, height, blue_last, start_with_green);
cudaSafeCall( cudaFuncSetCacheConfig(Bayer2BGR_8u<dst_t>, cudaFuncCachePreferL1) );
Bayer2BGR_8u<dst_t><<<grid, block, 0, stream>>>(src, (DevMem2D_<dst_t>)dst, blue_last, start_with_green);
cudaSafeCall( cudaGetLastError() );
if (stream == 0)
cudaSafeCall( cudaDeviceSynchronize() );
}
template <int cn>
void Bayer2BGR_16u_gpu(DevMem2Db src, DevMem2Db dst, bool blue_last, bool start_with_green, cudaStream_t stream)
{
typedef typename TypeVec<ushort, cn>::vec_type dst_t;
const dim3 block(32, 8);
const dim3 grid(divUp(dst.cols, 2 * block.x), divUp(dst.rows, block.y));
cudaSafeCall( cudaFuncSetCacheConfig(Bayer2BGR_16u<dst_t>, cudaFuncCachePreferL1) );
Bayer2BGR_16u<dst_t><<<grid, block, 0, stream>>>(src, (DevMem2D_<dst_t>)dst, blue_last, start_with_green);
cudaSafeCall( cudaGetLastError() );
if (stream == 0)
cudaSafeCall( cudaDeviceSynchronize() );
}
template void Bayer2BGR_gpu<uchar, 3>(DevMem2Db src, DevMem2Db dst, bool blue_last, bool start_with_green, cudaStream_t stream);
template void Bayer2BGR_gpu<uchar, 4>(DevMem2Db src, DevMem2Db dst, bool blue_last, bool start_with_green, cudaStream_t stream);
template void Bayer2BGR_gpu<ushort, 3>(DevMem2Db src, DevMem2Db dst, bool blue_last, bool start_with_green, cudaStream_t stream);
template void Bayer2BGR_gpu<ushort, 4>(DevMem2Db src, DevMem2Db dst, bool blue_last, bool start_with_green, cudaStream_t stream);
template void Bayer2BGR_8u_gpu<3>(DevMem2Db src, DevMem2Db dst, bool blue_last, bool start_with_green, cudaStream_t stream);
template void Bayer2BGR_8u_gpu<4>(DevMem2Db src, DevMem2Db dst, bool blue_last, bool start_with_green, cudaStream_t stream);
template void Bayer2BGR_16u_gpu<3>(DevMem2Db src, DevMem2Db dst, bool blue_last, bool start_with_green, cudaStream_t stream);
template void Bayer2BGR_16u_gpu<4>(DevMem2Db src, DevMem2Db dst, bool blue_last, bool start_with_green, cudaStream_t stream);
}
}}

View File

@ -41,6 +41,8 @@
#include "precomp.hpp"
#ifdef HAVE_CUDA
namespace {
///////////////////////////////////////////////////////////////////////////////////////////////////////
@ -1746,10 +1748,10 @@ TEST_P(CvtColor, RGBA2mRGBA)
TEST_P(CvtColor, BayerBG2BGR)
{
if (depth != CV_8U && depth != CV_16U)
if ((depth != CV_8U && depth != CV_16U) || useRoi)
return;
cv::Mat src = randomMat(size, CV_8UC1);
cv::Mat src = randomMat(size, depth);
cv::gpu::GpuMat dst;
cv::gpu::cvtColor(loadMat(src, useRoi), dst, cv::COLOR_BayerBG2BGR);
@ -1757,15 +1759,15 @@ TEST_P(CvtColor, BayerBG2BGR)
cv::Mat dst_gold;
cv::cvtColor(src, dst_gold, cv::COLOR_BayerBG2BGR);
EXPECT_MAT_NEAR(dst_gold, dst, 0);
EXPECT_MAT_NEAR(dst_gold(cv::Rect(1, 1, dst.cols - 2, dst.rows - 2)), dst(cv::Rect(1, 1, dst.cols - 2, dst.rows - 2)), 0);
}
TEST_P(CvtColor, BayerBG2BGR4)
{
if (depth != CV_8U && depth != CV_16U)
if ((depth != CV_8U && depth != CV_16U) || useRoi)
return;
cv::Mat src = randomMat(size, CV_8UC1);
cv::Mat src = randomMat(size, depth);
cv::gpu::GpuMat dst;
cv::gpu::cvtColor(loadMat(src, useRoi), dst, cv::COLOR_BayerBG2BGR, 4);
@ -1779,15 +1781,16 @@ TEST_P(CvtColor, BayerBG2BGR4)
cv::Mat dst3;
cv::cvtColor(dst4, dst3, cv::COLOR_BGRA2BGR);
EXPECT_MAT_NEAR(dst_gold, dst3, 0);
EXPECT_MAT_NEAR(dst_gold(cv::Rect(1, 1, dst.cols - 2, dst.rows - 2)), dst3(cv::Rect(1, 1, dst.cols - 2, dst.rows - 2)), 0);
}
TEST_P(CvtColor, BayerGB2BGR)
{
if (depth != CV_8U && depth != CV_16U)
if ((depth != CV_8U && depth != CV_16U) || useRoi)
return;
cv::Mat src = randomMat(size, CV_8UC1);
cv::Mat src = randomMat(size, depth);
cv::gpu::GpuMat dst;
cv::gpu::cvtColor(loadMat(src, useRoi), dst, cv::COLOR_BayerGB2BGR);
@ -1795,15 +1798,15 @@ TEST_P(CvtColor, BayerGB2BGR)
cv::Mat dst_gold;
cv::cvtColor(src, dst_gold, cv::COLOR_BayerGB2BGR);
EXPECT_MAT_NEAR(dst_gold, dst, 0);
EXPECT_MAT_NEAR(dst_gold(cv::Rect(1, 1, dst.cols - 2, dst.rows - 2)), dst(cv::Rect(1, 1, dst.cols - 2, dst.rows - 2)), 0);
}
TEST_P(CvtColor, BayerGB2BGR4)
{
if (depth != CV_8U && depth != CV_16U)
if ((depth != CV_8U && depth != CV_16U) || useRoi)
return;
cv::Mat src = randomMat(size, CV_8UC1);
cv::Mat src = randomMat(size, depth);
cv::gpu::GpuMat dst;
cv::gpu::cvtColor(loadMat(src, useRoi), dst, cv::COLOR_BayerGB2BGR, 4);
@ -1817,15 +1820,15 @@ TEST_P(CvtColor, BayerGB2BGR4)
cv::Mat dst3;
cv::cvtColor(dst4, dst3, cv::COLOR_BGRA2BGR);
EXPECT_MAT_NEAR(dst_gold, dst3, 0);
EXPECT_MAT_NEAR(dst_gold(cv::Rect(1, 1, dst.cols - 2, dst.rows - 2)), dst3(cv::Rect(1, 1, dst.cols - 2, dst.rows - 2)), 0);
}
TEST_P(CvtColor, BayerRG2BGR)
{
if (depth != CV_8U && depth != CV_16U)
if ((depth != CV_8U && depth != CV_16U) || useRoi)
return;
cv::Mat src = randomMat(size, CV_8UC1);
cv::Mat src = randomMat(size, depth);
cv::gpu::GpuMat dst;
cv::gpu::cvtColor(loadMat(src, useRoi), dst, cv::COLOR_BayerRG2BGR);
@ -1833,15 +1836,15 @@ TEST_P(CvtColor, BayerRG2BGR)
cv::Mat dst_gold;
cv::cvtColor(src, dst_gold, cv::COLOR_BayerRG2BGR);
EXPECT_MAT_NEAR(dst_gold, dst, 0);
EXPECT_MAT_NEAR(dst_gold(cv::Rect(1, 1, dst.cols - 2, dst.rows - 2)), dst(cv::Rect(1, 1, dst.cols - 2, dst.rows - 2)), 0);
}
TEST_P(CvtColor, BayerRG2BGR4)
{
if (depth != CV_8U && depth != CV_16U)
if ((depth != CV_8U && depth != CV_16U) || useRoi)
return;
cv::Mat src = randomMat(size, CV_8UC1);
cv::Mat src = randomMat(size, depth);
cv::gpu::GpuMat dst;
cv::gpu::cvtColor(loadMat(src, useRoi), dst, cv::COLOR_BayerRG2BGR, 4);
@ -1855,15 +1858,15 @@ TEST_P(CvtColor, BayerRG2BGR4)
cv::Mat dst3;
cv::cvtColor(dst4, dst3, cv::COLOR_BGRA2BGR);
EXPECT_MAT_NEAR(dst_gold, dst3, 0);
EXPECT_MAT_NEAR(dst_gold(cv::Rect(1, 1, dst.cols - 2, dst.rows - 2)), dst3(cv::Rect(1, 1, dst.cols - 2, dst.rows - 2)), 0);
}
TEST_P(CvtColor, BayerGR2BGR)
{
if (depth != CV_8U && depth != CV_16U)
if ((depth != CV_8U && depth != CV_16U) || useRoi)
return;
cv::Mat src = randomMat(size, CV_8UC1);
cv::Mat src = randomMat(size, depth);
cv::gpu::GpuMat dst;
cv::gpu::cvtColor(loadMat(src, useRoi), dst, cv::COLOR_BayerGR2BGR);
@ -1871,15 +1874,15 @@ TEST_P(CvtColor, BayerGR2BGR)
cv::Mat dst_gold;
cv::cvtColor(src, dst_gold, cv::COLOR_BayerGR2BGR);
EXPECT_MAT_NEAR(dst_gold, dst, 0);
EXPECT_MAT_NEAR(dst_gold(cv::Rect(1, 1, dst.cols - 2, dst.rows - 2)), dst(cv::Rect(1, 1, dst.cols - 2, dst.rows - 2)), 0);
}
TEST_P(CvtColor, BayerGR2BGR4)
{
if (depth != CV_8U && depth != CV_16U)
if ((depth != CV_8U && depth != CV_16U) || useRoi)
return;
cv::Mat src = randomMat(size, CV_8UC1);
cv::Mat src = randomMat(size, depth);
cv::gpu::GpuMat dst;
cv::gpu::cvtColor(loadMat(src, useRoi), dst, cv::COLOR_BayerGR2BGR, 4);
@ -1893,7 +1896,7 @@ TEST_P(CvtColor, BayerGR2BGR4)
cv::Mat dst3;
cv::cvtColor(dst4, dst3, cv::COLOR_BGRA2BGR);
EXPECT_MAT_NEAR(dst_gold, dst3, 0);
EXPECT_MAT_NEAR(dst_gold(cv::Rect(1, 1, dst.cols - 2, dst.rows - 2)), dst3(cv::Rect(1, 1, dst.cols - 2, dst.rows - 2)), 0);
}
INSTANTIATE_TEST_CASE_P(GPU_ImgProc, CvtColor, testing::Combine(
@ -1943,3 +1946,5 @@ INSTANTIATE_TEST_CASE_P(GPU_ImgProc, SwapChannels, testing::Combine(
WHOLE_SUBMAT));
} // namespace
#endif // HAVE_CUDA