diff --git a/modules/gpu/include/opencv2/gpu/device/scan.hpp b/modules/gpu/include/opencv2/gpu/device/scan.hpp
index f6dc6937fb..3d8da16f84 100644
--- a/modules/gpu/include/opencv2/gpu/device/scan.hpp
+++ b/modules/gpu/include/opencv2/gpu/device/scan.hpp
@@ -43,7 +43,10 @@
 #ifndef __OPENCV_GPU_SCAN_HPP__
 #define __OPENCV_GPU_SCAN_HPP__
 
-#include "common.hpp"
+#include "opencv2/gpu/device/common.hpp"
+#include "opencv2/gpu/device/utility.hpp"
+#include "opencv2/gpu/device/warp.hpp"
+#include "opencv2/gpu/device/warp_shuffle.hpp"
 
 namespace cv { namespace gpu { namespace device
 {
@@ -166,6 +169,82 @@ namespace cv { namespace gpu { namespace device
         static const int warp_log  = 5;
         static const int warp_mask = 31;
     };
+
+    template <typename T>
+    __device__ T warpScanInclusive(T idata, volatile T* s_Data, unsigned int tid)
+    {
+    #if __CUDA_ARCH__ >= 300
+        const unsigned int laneId = cv::gpu::device::Warp::laneId();
+
+        // scan on shuffle functions
+        #pragma unroll
+        for (int i = 1; i <= (OPENCV_GPU_WARP_SIZE / 2); i *= 2)
+        {
+            const T n = cv::gpu::device::shfl_up(idata, i);
+            if (laneId >= i)
+                idata += n;
+        }
+
+        return idata;
+    #else
+        unsigned int pos = 2 * tid - (tid & (OPENCV_GPU_WARP_SIZE - 1));
+        s_Data[pos] = 0;
+        pos += OPENCV_GPU_WARP_SIZE;
+        s_Data[pos] = idata;
+
+        s_Data[pos] += s_Data[pos - 1];
+        s_Data[pos] += s_Data[pos - 2];
+        s_Data[pos] += s_Data[pos - 4];
+        s_Data[pos] += s_Data[pos - 8];
+        s_Data[pos] += s_Data[pos - 16];
+
+        return s_Data[pos];
+    #endif
+    }
+
+    template <typename T>
+    __device__ __forceinline__ T warpScanExclusive(T idata, volatile T* s_Data, unsigned int tid)
+    {
+        return warpScanInclusive(idata, s_Data, tid) - idata;
+    }
+
+    template <int tiNumScanThreads, typename T>
+    __device__ T blockScanInclusive(T idata, volatile T* s_Data, unsigned int tid)
+    {
+        if (tiNumScanThreads > OPENCV_GPU_WARP_SIZE)
+        {
+            //Bottom-level inclusive warp scan
+            T warpResult = warpScanInclusive(idata, s_Data, tid);
+
+            //Save top elements of each warp for exclusive warp scan
+            //sync to wait for warp scans to complete (because s_Data is being overwritten)
+            __syncthreads();
+            if ((tid & (OPENCV_GPU_WARP_SIZE - 1)) == (OPENCV_GPU_WARP_SIZE - 1))
+            {
+                s_Data[tid >> OPENCV_GPU_LOG_WARP_SIZE] = warpResult;
+            }
+
+            //wait for warp scans to complete
+            __syncthreads();
+
+            if (tid < (tiNumScanThreads / OPENCV_GPU_WARP_SIZE))
+            {
+                //grab top warp elements
+                T val = s_Data[tid];
+                //calculate exclusive scan and write back to shared memory
+                s_Data[tid] = warpScanExclusive(val, s_Data, tid);
+            }
+
+            //return updated warp scans with exclusive scan results
+            __syncthreads();
+
+            return warpResult + s_Data[tid >> OPENCV_GPU_LOG_WARP_SIZE];
+        }
+        else
+        {
+            return warpScanInclusive(idata, s_Data, tid);
+        }
+    }
 }}}
 
 #endif // __OPENCV_GPU_SCAN_HPP__
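For context, blockScanInclusive is the primitive that the new clahe.cu kernel below uses to turn a per-tile 256-bin histogram into a cumulative distribution. A minimal, hypothetical kernel (not part of this patch; names are illustrative) shows how it is meant to be called: one 256-thread block per histogram, with the 512-entry shared buffer that the pre-SM30 fallback path of warpScanInclusive requires.

#include "opencv2/gpu/device/scan.hpp"

// Assumed launch: cdfKernel<<<numHistograms, 256>>>(hist, cdf);
__global__ void cdfKernel(const int* hist, int* cdf)
{
    // 512 ints: the shared-memory fallback in warpScanInclusive indexes up to 511
    __shared__ int smem[512];

    const unsigned int tid = threadIdx.x;
    int val = hist[blockIdx.x * 256 + tid];

    // inclusive prefix sum over the 256 values held by this block
    const int sum = cv::gpu::device::blockScanInclusive<256>(val, smem, tid);

    cdf[blockIdx.x * 256 + tid] = sum;
}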
diff --git a/modules/gpu/include/opencv2/gpu/gpu.hpp b/modules/gpu/include/opencv2/gpu/gpu.hpp
index bd7085a004..e2fc99b90f 100644
--- a/modules/gpu/include/opencv2/gpu/gpu.hpp
+++ b/modules/gpu/include/opencv2/gpu/gpu.hpp
@@ -1062,6 +1062,14 @@ CV_EXPORTS void equalizeHist(const GpuMat& src, GpuMat& dst, Stream& stream = St
 CV_EXPORTS void equalizeHist(const GpuMat& src, GpuMat& dst, GpuMat& hist, Stream& stream = Stream::Null());
 CV_EXPORTS void equalizeHist(const GpuMat& src, GpuMat& dst, GpuMat& hist, GpuMat& buf, Stream& stream = Stream::Null());
 
+class CV_EXPORTS CLAHE : public cv::CLAHE
+{
+public:
+    using cv::CLAHE::apply;
+    virtual void apply(InputArray src, OutputArray dst, Stream& stream) = 0;
+};
+CV_EXPORTS Ptr<cv::gpu::CLAHE> createCLAHE(double clipLimit = 40.0, Size tileGridSize = Size(8, 8));
+
 //////////////////////////////// StereoBM_GPU ////////////////////////////////
 
 class CV_EXPORTS StereoBM_GPU
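The new class only adds a stream-aware apply overload on top of the existing cv::CLAHE interface, so host code stays close to the CPU path. A minimal usage sketch, assuming a CV_8UC1 image already on the device (function and variable names are illustrative, not part of the patch):

#include <opencv2/gpu/gpu.hpp>

void enhanceAsync(const cv::gpu::GpuMat& gray, cv::gpu::GpuMat& result, cv::gpu::Stream& stream)
{
    // 40.0 and 8x8 are the defaults declared above
    cv::Ptr<cv::gpu::CLAHE> clahe = cv::gpu::createCLAHE(40.0, cv::Size(8, 8));

    // stream-aware overload added by this patch; the inherited apply(src, dst)
    // from cv::CLAHE runs on the null stream
    clahe->apply(gray, result, stream);
}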
diff --git a/modules/gpu/perf/perf_imgproc.cpp b/modules/gpu/perf/perf_imgproc.cpp
index 9f4d673594..eff2bfcf2e 100644
--- a/modules/gpu/perf/perf_imgproc.cpp
+++ b/modules/gpu/perf/perf_imgproc.cpp
@@ -600,6 +600,39 @@ PERF_TEST_P(Sz, ImgProc_EqualizeHist,
     }
 }
 
+DEF_PARAM_TEST(Sz_ClipLimit, cv::Size, double);
+
+PERF_TEST_P(Sz_ClipLimit, ImgProc_CLAHE,
+            Combine(GPU_TYPICAL_MAT_SIZES,
+                    Values(0.0, 40.0)))
+{
+    const cv::Size size = GET_PARAM(0);
+    const double clipLimit = GET_PARAM(1);
+
+    cv::Mat src(size, CV_8UC1);
+    declare.in(src, WARMUP_RNG);
+
+    if (PERF_RUN_GPU())
+    {
+        cv::Ptr<cv::gpu::CLAHE> clahe = cv::gpu::createCLAHE(clipLimit);
+        cv::gpu::GpuMat d_src(src);
+        cv::gpu::GpuMat dst;
+
+        TEST_CYCLE() clahe->apply(d_src, dst);
+
+        GPU_SANITY_CHECK(dst);
+    }
+    else
+    {
+        cv::Ptr<cv::CLAHE> clahe = cv::createCLAHE(clipLimit);
+        cv::Mat dst;
+
+        TEST_CYCLE() clahe->apply(src, dst);
+
+        CPU_SANITY_CHECK(dst);
+    }
+}
+
 //////////////////////////////////////////////////////////////////////
 // ColumnSum
 
diff --git a/modules/gpu/src/cuda/clahe.cu b/modules/gpu/src/cuda/clahe.cu
new file mode 100644
index 0000000000..a0c30d5826
--- /dev/null
+++ b/modules/gpu/src/cuda/clahe.cu
@@ -0,0 +1,186 @@
+/*M///////////////////////////////////////////////////////////////////////////////////////
+//
+//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
+//
+//  By downloading, copying, installing or using the software you agree to this license.
+//  If you do not agree to this license, do not download, install,
+//  copy or use the software.
+//
+//
+//                           License Agreement
+//                For Open Source Computer Vision Library
+//
+// Copyright (C) 2000-2008, Intel Corporation, all rights reserved.
+// Copyright (C) 2009, Willow Garage Inc., all rights reserved.
+// Third party copyrights are property of their respective owners.
+//
+// Redistribution and use in source and binary forms, with or without modification,
+// are permitted provided that the following conditions are met:
+//
+//   * Redistribution's of source code must retain the above copyright notice,
+//     this list of conditions and the following disclaimer.
+//
+//   * Redistribution's in binary form must reproduce the above copyright notice,
+//     this list of conditions and the following disclaimer in the documentation
+//     and/or other materials provided with the distribution.
+//
+//   * The name of the copyright holders may not be used to endorse or promote products
+//     derived from this software without specific prior written permission.
+//
+// This software is provided by the copyright holders and contributors "as is" and
+// any express or implied warranties, including, but not limited to, the implied
+// warranties of merchantability and fitness for a particular purpose are disclaimed.
+// In no event shall the Intel Corporation or contributors be liable for any direct,
+// indirect, incidental, special, exemplary, or consequential damages
+// (including, but not limited to, procurement of substitute goods or services;
+// loss of use, data, or profits; or business interruption) however caused
+// and on any theory of liability, whether in contract, strict liability,
+// or tort (including negligence or otherwise) arising in any way out of
+// the use of this software, even if advised of the possibility of such damage.
+//
+//M*/
+
+#if !defined CUDA_DISABLER
+
+#include "opencv2/gpu/device/common.hpp"
+#include "opencv2/gpu/device/functional.hpp"
+#include "opencv2/gpu/device/emulation.hpp"
+#include "opencv2/gpu/device/scan.hpp"
+#include "opencv2/gpu/device/reduce.hpp"
+#include "opencv2/gpu/device/saturate_cast.hpp"
+
+using namespace cv::gpu;
+using namespace cv::gpu::device;
+
+namespace clahe
+{
+    __global__ void calcLutKernel(const PtrStepb src, PtrStepb lut,
+                                  const int2 tileSize, const int tilesX,
+                                  const int clipLimit, const float lutScale)
+    {
+        __shared__ int smem[512];
+
+        const int tx = blockIdx.x;
+        const int ty = blockIdx.y;
+        const unsigned int tid = threadIdx.y * blockDim.x + threadIdx.x;
+
+        smem[tid] = 0;
+        __syncthreads();
+
+        for (int i = threadIdx.y; i < tileSize.y; i += blockDim.y)
+        {
+            const uchar* srcPtr = src.ptr(ty * tileSize.y + i) + tx * tileSize.x;
+            for (int j = threadIdx.x; j < tileSize.x; j += blockDim.x)
+            {
+                const int data = srcPtr[j];
+                Emulation::smem::atomicAdd(&smem[data], 1);
+            }
+        }
+
+        __syncthreads();
+
+        int tHistVal = smem[tid];
+
+        __syncthreads();
+
+        if (clipLimit > 0)
+        {
+            // clip histogram bar
+
+            int clipped = 0;
+            if (tHistVal > clipLimit)
+            {
+                clipped = tHistVal - clipLimit;
+                tHistVal = clipLimit;
+            }
+
+            // find number of overall clipped samples
+
+            reduce<256>(smem, clipped, tid, plus<int>());
+
+            // broadcast evaluated value
+
+            __shared__ int totalClipped;
+
+            if (tid == 0)
+                totalClipped = clipped;
+            __syncthreads();
+
+            // redistribute clipped samples evenly
+
+            int redistBatch = totalClipped / 256;
+            tHistVal += redistBatch;
+
+            int residual = totalClipped - redistBatch * 256;
+            if (tid < residual)
+                ++tHistVal;
+        }
+
+        const int lutVal = blockScanInclusive<256>(tHistVal, smem, tid);
+
+        lut(ty * tilesX + tx, tid) = saturate_cast<uchar>(__float2int_rn(lutScale * lutVal));
+    }
+
+    void calcLut(PtrStepSzb src, PtrStepb lut, int tilesX, int tilesY, int2 tileSize, int clipLimit, float lutScale, cudaStream_t stream)
+    {
+        const dim3 block(32, 8);
+        const dim3 grid(tilesX, tilesY);
+
+        calcLutKernel<<<grid, block, 0, stream>>>(src, lut, tileSize, tilesX, clipLimit, lutScale);
+
+        cudaSafeCall( cudaGetLastError() );
+
+        if (stream == 0)
+            cudaSafeCall( cudaDeviceSynchronize() );
+    }
+
+    __global__ void tranformKernel(const PtrStepSzb src, PtrStepb dst, const PtrStepb lut, const int2 tileSize, const int tilesX, const int tilesY)
+    {
+        const int x = blockIdx.x * blockDim.x + threadIdx.x;
+        const int y = blockIdx.y * blockDim.y + threadIdx.y;
+
+        if (x >= src.cols || y >= src.rows)
+            return;
+
+        const float tyf = (static_cast<float>(y) / tileSize.y) - 0.5f;
+        int ty1 = __float2int_rd(tyf);
+        int ty2 = ty1 + 1;
+        const float ya = tyf - ty1;
+        ty1 = ::max(ty1, 0);
+        ty2 = ::min(ty2, tilesY - 1);
+
+        const float txf = (static_cast<float>(x) / tileSize.x) - 0.5f;
+        int tx1 = __float2int_rd(txf);
+        int tx2 = tx1 + 1;
+        const float xa = txf - tx1;
+        tx1 = ::max(tx1, 0);
+        tx2 = ::min(tx2, tilesX - 1);
+
+        const int srcVal = src(y, x);
+
+        float res = 0;
+
+        res += lut(ty1 * tilesX + tx1, srcVal) * ((1.0f - xa) * (1.0f - ya));
+        res += lut(ty1 * tilesX + tx2, srcVal) * ((xa) * (1.0f - ya));
+        res += lut(ty2 * tilesX + tx1, srcVal) * ((1.0f - xa) * (ya));
+        res += lut(ty2 * tilesX + tx2, srcVal) * ((xa) * (ya));
+
+        dst(y, x) = saturate_cast<uchar>(res);
+    }
+
+    void transform(PtrStepSzb src, PtrStepSzb dst, PtrStepb lut, int tilesX, int tilesY, int2 tileSize, cudaStream_t stream)
+    {
+        const dim3 block(32, 8);
+        const dim3 grid(divUp(src.cols, block.x), divUp(src.rows, block.y));
+
+        cudaSafeCall( cudaFuncSetCacheConfig(tranformKernel, cudaFuncCachePreferL1) );
+
+        tranformKernel<<<grid, block, 0, stream>>>(src, dst, lut, tileSize, tilesX, tilesY);
+        cudaSafeCall( cudaGetLastError() );
+
+        if (stream == 0)
+            cudaSafeCall( cudaDeviceSynchronize() );
+    }
+}
+
+#endif // CUDA_DISABLER
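To summarize what calcLutKernel produces per tile, here is a hedged CPU reference of the same clip / redistribute / scan / scale sequence (illustrative code, not part of the patch; the rounding differs slightly, since the kernel uses __float2int_rn while this sketch rounds half up):

#include <algorithm>
#include <vector>

// One tile's LUT: cap each of the 256 bins at clipLimit, spread the clipped mass
// evenly over all bins (with the first `residual` bins taking one extra count,
// mirroring the `tid < residual` branch above), then scale the inclusive prefix sum.
static void calcTileLutReference(std::vector<int> hist, int clipLimit, float lutScale,
                                 unsigned char lut[256])
{
    if (clipLimit > 0)
    {
        int totalClipped = 0;
        for (int i = 0; i < 256; ++i)
        {
            if (hist[i] > clipLimit)
            {
                totalClipped += hist[i] - clipLimit;
                hist[i] = clipLimit;
            }
        }

        const int redistBatch = totalClipped / 256;
        const int residual = totalClipped - redistBatch * 256;
        for (int i = 0; i < 256; ++i)
            hist[i] += redistBatch + (i < residual ? 1 : 0);
    }

    int sum = 0;
    for (int i = 0; i < 256; ++i)
    {
        sum += hist[i];  // same role as blockScanInclusive<256> in the kernel
        lut[i] = static_cast<unsigned char>(std::min(255, static_cast<int>(lutScale * sum + 0.5f)));
    }
}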
diff --git a/modules/gpu/src/imgproc.cpp b/modules/gpu/src/imgproc.cpp
index d9ca46844e..97c7c766c0 100644
--- a/modules/gpu/src/imgproc.cpp
+++ b/modules/gpu/src/imgproc.cpp
@@ -96,6 +96,7 @@ void cv::gpu::Canny(const GpuMat&, const GpuMat&, GpuMat&, double, double, bool)
 void cv::gpu::Canny(const GpuMat&, const GpuMat&, CannyBuf&, GpuMat&, double, double, bool) { throw_nogpu(); }
 void cv::gpu::CannyBuf::create(const Size&, int) { throw_nogpu(); }
 void cv::gpu::CannyBuf::release() { throw_nogpu(); }
+cv::Ptr<cv::gpu::CLAHE> cv::gpu::createCLAHE(double, cv::Size) { throw_nogpu(); return cv::Ptr<cv::gpu::CLAHE>(); }
 
 #else /* !defined (HAVE_CUDA) */
 
@@ -1559,4 +1560,136 @@ void cv::gpu::Canny(const GpuMat& dx, const GpuMat& dy, CannyBuf& buf, GpuMat& d
     CannyCaller(dx, dy, buf, dst, static_cast<float>(low_thresh), static_cast<float>(high_thresh));
 }
 
+////////////////////////////////////////////////////////////////////////
+// CLAHE
+
+namespace clahe
+{
+    void calcLut(PtrStepSzb src, PtrStepb lut, int tilesX, int tilesY, int2 tileSize, int clipLimit, float lutScale, cudaStream_t stream);
+    void transform(PtrStepSzb src, PtrStepSzb dst, PtrStepb lut, int tilesX, int tilesY, int2 tileSize, cudaStream_t stream);
+}
+
+namespace
+{
+    class CLAHE_Impl : public cv::gpu::CLAHE
+    {
+    public:
+        CLAHE_Impl(double clipLimit = 40.0, int tilesX = 8, int tilesY = 8);
+
+        cv::AlgorithmInfo* info() const;
+
+        void apply(cv::InputArray src, cv::OutputArray dst);
+        void apply(InputArray src, OutputArray dst, Stream& stream);
+
+        void setClipLimit(double clipLimit);
+        double getClipLimit() const;
+
+        void setTilesGridSize(cv::Size tileGridSize);
+        cv::Size getTilesGridSize() const;
+
+        void collectGarbage();
+
+    private:
+        double clipLimit_;
+        int tilesX_;
+        int tilesY_;
+
+        GpuMat srcExt_;
+        GpuMat lut_;
+    };
+
+    CLAHE_Impl::CLAHE_Impl(double clipLimit, int tilesX, int tilesY) :
+        clipLimit_(clipLimit), tilesX_(tilesX), tilesY_(tilesY)
+    {
+    }
+
+    CV_INIT_ALGORITHM(CLAHE_Impl, "CLAHE_GPU",
+        obj.info()->addParam(obj, "clipLimit", obj.clipLimit_);
+        obj.info()->addParam(obj, "tilesX", obj.tilesX_);
+        obj.info()->addParam(obj, "tilesY", obj.tilesY_))
+
+    void CLAHE_Impl::apply(cv::InputArray _src, cv::OutputArray _dst)
+    {
+        apply(_src, _dst, Stream::Null());
+    }
+
+    void CLAHE_Impl::apply(InputArray _src, OutputArray _dst, Stream& s)
+    {
+        GpuMat src = _src.getGpuMat();
+
+        CV_Assert( src.type() == CV_8UC1 );
+
+        _dst.create( src.size(), src.type() );
+        GpuMat dst = _dst.getGpuMat();
+
+        const int histSize = 256;
+
+        ensureSizeIsEnough(tilesX_ * tilesY_, histSize, CV_8UC1, lut_);
+
+        cudaStream_t stream = StreamAccessor::getStream(s);
+
+        cv::Size tileSize;
+        GpuMat srcForLut;
+
+        if (src.cols % tilesX_ == 0 && src.rows % tilesY_ == 0)
+        {
+            tileSize = cv::Size(src.cols / tilesX_, src.rows / tilesY_);
+            srcForLut = src;
+        }
+        else
+        {
+            cv::gpu::copyMakeBorder(src, srcExt_, 0, tilesY_ - (src.rows % tilesY_), 0, tilesX_ - (src.cols % tilesX_), cv::BORDER_REFLECT_101, cv::Scalar(), s);
+
+            tileSize = cv::Size(srcExt_.cols / tilesX_, srcExt_.rows / tilesY_);
+            srcForLut = srcExt_;
+        }
+
+        const int tileSizeTotal = tileSize.area();
+        const float lutScale = static_cast<float>(histSize - 1) / tileSizeTotal;
+
+        int clipLimit = 0;
+        if (clipLimit_ > 0.0)
+        {
+            clipLimit = static_cast<int>(clipLimit_ * tileSizeTotal / histSize);
+            clipLimit = std::max(clipLimit, 1);
+        }
+
+        clahe::calcLut(srcForLut, lut_, tilesX_, tilesY_, make_int2(tileSize.width, tileSize.height), clipLimit, lutScale, stream);
+
+        clahe::transform(src, dst, lut_, tilesX_, tilesY_, make_int2(tileSize.width, tileSize.height), stream);
+    }
+
+    void CLAHE_Impl::setClipLimit(double clipLimit)
+    {
+        clipLimit_ = clipLimit;
+    }
+
+    double CLAHE_Impl::getClipLimit() const
+    {
+        return clipLimit_;
+    }
+
+    void CLAHE_Impl::setTilesGridSize(cv::Size tileGridSize)
+    {
+        tilesX_ = tileGridSize.width;
+        tilesY_ = tileGridSize.height;
+    }
+
+    cv::Size CLAHE_Impl::getTilesGridSize() const
+    {
+        return cv::Size(tilesX_, tilesY_);
+    }
+
+    void CLAHE_Impl::collectGarbage()
+    {
+        srcExt_.release();
+        lut_.release();
+    }
+}
+
+cv::Ptr<cv::gpu::CLAHE> cv::gpu::createCLAHE(double clipLimit, cv::Size tileGridSize)
+{
+    return new CLAHE_Impl(clipLimit, tileGridSize.width, tileGridSize.height);
+}
+
 #endif /* !defined (HAVE_CUDA) */
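To make the scaling in CLAHE_Impl::apply concrete: for a 512x512 CV_8UC1 input with the default 8x8 grid, each tile is 64x64, so tileSizeTotal = 4096, lutScale = 255 / 4096 ≈ 0.062, and the default clipLimit_ of 40.0 becomes an integer per-bin cap of 40 * 4096 / 256 = 640 samples.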
diff --git a/modules/gpu/test/test_imgproc.cpp b/modules/gpu/test/test_imgproc.cpp
index 3341737415..925ca9d7ef 100644
--- a/modules/gpu/test/test_imgproc.cpp
+++ b/modules/gpu/test/test_imgproc.cpp
@@ -217,6 +217,50 @@ INSTANTIATE_TEST_CASE_P(GPU_ImgProc, EqualizeHist, testing::Combine(
     ALL_DEVICES,
     DIFFERENT_SIZES));
 
+///////////////////////////////////////////////////////////////////////////////////////////////////////
+// CLAHE
+
+namespace
+{
+    IMPLEMENT_PARAM_CLASS(ClipLimit, double)
+}
+
+PARAM_TEST_CASE(CLAHE, cv::gpu::DeviceInfo, cv::Size, ClipLimit)
+{
+    cv::gpu::DeviceInfo devInfo;
+    cv::Size size;
+    double clipLimit;
+
+    virtual void SetUp()
+    {
+        devInfo = GET_PARAM(0);
+        size = GET_PARAM(1);
+        clipLimit = GET_PARAM(2);
+
+        cv::gpu::setDevice(devInfo.deviceID());
+    }
+};
+
+GPU_TEST_P(CLAHE, Accuracy)
+{
+    cv::Mat src = randomMat(size, CV_8UC1);
+
+    cv::Ptr<cv::gpu::CLAHE> clahe = cv::gpu::createCLAHE(clipLimit);
+    cv::gpu::GpuMat dst;
+    clahe->apply(loadMat(src), dst);
+
+    cv::Ptr<cv::CLAHE> clahe_gold = cv::createCLAHE(clipLimit);
+    cv::Mat dst_gold;
+    clahe_gold->apply(src, dst_gold);
+
+    ASSERT_MAT_NEAR(dst_gold, dst, 1.0);
+}
+
+INSTANTIATE_TEST_CASE_P(GPU_ImgProc, CLAHE, testing::Combine(
+    ALL_DEVICES,
+    DIFFERENT_SIZES,
+    testing::Values(0.0, 40.0)));
+
 ////////////////////////////////////////////////////////////////////////
 // ColumnSum