refactor: PrefixSum

This commit is contained in:
marina.kolpakova 2012-11-14 14:47:00 +04:00
parent a30bbda3bd
commit 781c04324e

View File

@ -79,6 +79,39 @@ namespace icf {
} }
} }
template<typename Policy>
struct PrefixSum
{
__device static void apply(float& impact)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 300
#pragma unroll
// scan on shuffl functions
for (int i = 1; i < Policy::WARP; i *= 2)
{
const float n = __shfl_up(impact, i, Policy::WARP);
if (threadIdx.x >= i)
impact += n;
}
#else
__shared__ volatile float ptr[Policy::STA_X * Policy::STA_Y];
const int idx = threadIdx.y * Policy::STA_X + threadIdx.x;
ptr[idx] = impact;
if ( threadIdx.x >= 1) ptr [idx ] = (ptr [idx - 1] + ptr [idx]);
if ( threadIdx.x >= 2) ptr [idx ] = (ptr [idx - 2] + ptr [idx]);
if ( threadIdx.x >= 4) ptr [idx ] = (ptr [idx - 4] + ptr [idx]);
if ( threadIdx.x >= 8) ptr [idx ] = (ptr [idx - 8] + ptr [idx]);
if ( threadIdx.x >= 16) ptr [idx ] = (ptr [idx - 16] + ptr [idx]);
impact = ptr[idx];
#endif
}
};
texture<int, cudaTextureType2D, cudaReadModeElementType> thogluv; texture<int, cudaTextureType2D, cudaReadModeElementType> thogluv;
template<bool isUp> template<bool isUp>
@ -201,32 +234,9 @@ __device void CascadeInvoker<Policy>::detect(Detection* objects, const uint ndet
const int lShift = (next - 1) * 2 + (int)(sum >= threshold); const int lShift = (next - 1) * 2 + (int)(sum >= threshold);
float impact = leaves[(st + threadIdx.x) * 4 + lShift]; float impact = leaves[(st + threadIdx.x) * 4 + lShift];
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 300 PrefixSum<Policy>::apply(impact);
#pragma unroll
// scan on shuffl functions
for (int i = 1; i < Policy::WARP; i *= 2)
{
const float n = __shfl_up(impact, i, Policy::WARP);
if (threadIdx.x >= i)
impact += n;
}
#else
__shared__ volatile float ptr[Policy::STA_X * Policy::STA_Y];
const int idx = threadIdx.y * Policy::STA_X + threadIdx.x;
ptr[idx] = impact;
if ( threadIdx.x >= 1) ptr [idx ] = (ptr [idx - 1] + ptr [idx]);
if ( threadIdx.x >= 2) ptr [idx ] = (ptr [idx - 2] + ptr [idx]);
if ( threadIdx.x >= 4) ptr [idx ] = (ptr [idx - 4] + ptr [idx]);
if ( threadIdx.x >= 8) ptr [idx ] = (ptr [idx - 8] + ptr [idx]);
if ( threadIdx.x >= 16) ptr [idx ] = (ptr [idx - 16] + ptr [idx]);
impact = ptr[idx];
#endif
confidence += impact; confidence += impact;
if(__any((confidence <= stages[(st + threadIdx.x)]))) st += 2048; if(__any((confidence <= stages[(st + threadIdx.x)]))) st += 2048;
} }