mirror of
https://github.com/opencv/opencv.git
synced 2025-06-12 20:42:53 +08:00
Merge pull request #16220 from YashasSamaga:cuda4dnn-roi-pooling-test_fix-optim
This commit is contained in:
commit
2ced568d34
@ -24,49 +24,75 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
|||||||
|
|
||||||
namespace raw {
|
namespace raw {
|
||||||
|
|
||||||
template <class T>
|
template <class T, std::size_t CHANNELS_PER_ITER>
|
||||||
__global__ void roi_pooling(
|
__global__ void roi_pooling(
|
||||||
Span<T> output, size_type pooled_height, size_type pooled_width,
|
Span<T> output, size_type pooled_height, size_type pooled_width,
|
||||||
View<T> input, size_type in_height, size_type in_width,
|
View<T> input, size_type in_height, size_type in_width,
|
||||||
View<T> rois, size_type num_channels, T spatial_scale)
|
View<T> rois, size_type num_channels, float spatial_scale)
|
||||||
{
|
{
|
||||||
// input: [1, num_channels, in_height, in_width]
|
// input: [1, num_channels, in_height, in_width]
|
||||||
|
const auto in_image_size = in_height * in_width;
|
||||||
|
|
||||||
// rois: [num_rois, 5]
|
// rois: [num_rois, 5]
|
||||||
|
auto num_rois = rois.size() / 5;
|
||||||
|
|
||||||
// output: [num_rois, num_channels, pooled_height, pooled_width]
|
// output: [num_rois, num_channels, pooled_height, pooled_width]
|
||||||
const auto out_spatial_size = pooled_height * pooled_width;
|
const auto out_spatial_size = pooled_height * pooled_width;
|
||||||
const auto out_roi_size = num_channels * out_spatial_size;
|
const auto out_roi_size = num_channels * out_spatial_size;
|
||||||
|
|
||||||
/* every element in the output is mapped to a window in the input and each thread processes several windows */
|
/* we have to compute the output value for every combination of (roi, c, y, x) in the output
|
||||||
for (auto idx : grid_stride_range(output.size()))
|
*
|
||||||
{
|
* the computation involving (y, x) are identical for all non-spatial dimensions
|
||||||
const auto n = idx / out_roi_size;
|
* the computation and memory requests involving the roi are identical for remaining three axes
|
||||||
const auto c = (idx % out_roi_size) / out_spatial_size;
|
*
|
||||||
const auto y = (idx % out_spatial_size) / pooled_width;
|
* we process multiple channels every iteration to reuse the identical computation
|
||||||
const auto x = idx % pooled_width;
|
* and memory requests involved with the roi and spatial dimensions
|
||||||
|
*/
|
||||||
|
/*
|
||||||
|
* if we are processing `CHANNELS_PER_ITER` channels per iteration, we will need
|
||||||
|
* (num_channels / CHANNELS_PER_ITER) iterations per (roi, x, y)
|
||||||
|
*/
|
||||||
|
auto num_channel_iters_per_roi_xy = num_channels / CHANNELS_PER_ITER;
|
||||||
|
|
||||||
const index_type roi_offset = n * 5;
|
/* we need `num_channel_iters_per_roi_xy` iterations per (roi, x, y) and there are
|
||||||
|
* `num_rois` rois and `out_spatial_size` combinations of (x, y)
|
||||||
|
*/
|
||||||
|
auto iters_per_roi = num_channel_iters_per_roi_xy * out_spatial_size;
|
||||||
|
auto iters_required = num_rois * iters_per_roi;
|
||||||
|
|
||||||
|
for (auto iter : grid_stride_range(iters_required))
|
||||||
|
{
|
||||||
|
const index_type roi_no = iter / iters_per_roi;
|
||||||
|
const index_type c_start = ((iter % iters_per_roi) / out_spatial_size) * CHANNELS_PER_ITER;
|
||||||
|
|
||||||
|
/* note here that consecutive `iter` values will often have consecutive `x` values
|
||||||
|
* => stores into output will be coalesced across threads
|
||||||
|
*/
|
||||||
|
const index_type y = (iter % out_spatial_size) / pooled_width;
|
||||||
|
const index_type x = iter % pooled_width;
|
||||||
|
|
||||||
|
const index_type roi_offset = roi_no * 5;
|
||||||
|
|
||||||
using device::round;
|
using device::round;
|
||||||
const index_type batch_id = rois[roi_offset + 0];
|
const index_type batch_id = rois[roi_offset + 0];
|
||||||
const index_type x_start_roi = round(rois[roi_offset + 1] * spatial_scale);
|
const index_type x_start_roi = round(static_cast<float>(rois[roi_offset + 1]) * spatial_scale);
|
||||||
const index_type y_start_roi = round(rois[roi_offset + 2] * spatial_scale);
|
const index_type y_start_roi = round(static_cast<float>(rois[roi_offset + 2]) * spatial_scale);
|
||||||
const index_type x_end_roi = round(rois[roi_offset + 3] * spatial_scale);
|
const index_type x_end_roi = round(static_cast<float>(rois[roi_offset + 3]) * spatial_scale);
|
||||||
const index_type y_end_roi = round(rois[roi_offset + 4] * spatial_scale);
|
const index_type y_end_roi = round(static_cast<float>(rois[roi_offset + 4]) * spatial_scale);
|
||||||
|
|
||||||
using device::max;
|
using device::max;
|
||||||
const auto roi_width = max<index_type>(x_end_roi - x_start_roi + 1, 1);
|
const auto roi_width = max<index_type>(x_end_roi - x_start_roi + 1, 1);
|
||||||
const auto roi_height = max<index_type>(y_end_roi - y_start_roi + 1, 1);
|
const auto roi_height = max<index_type>(y_end_roi - y_start_roi + 1, 1);
|
||||||
|
|
||||||
const auto roi_width_ratio = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
|
const auto roi_width_ratio = static_cast<float>(roi_width) / pooled_width;
|
||||||
const auto roi_height_ratio = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
|
const auto roi_height_ratio = static_cast<float>(roi_height) / pooled_height;
|
||||||
|
|
||||||
auto x_start = x_start_roi + static_cast<index_type>(static_cast<T>(x) * roi_width_ratio);
|
auto x_start = x_start_roi + static_cast<index_type>(x * roi_width_ratio);
|
||||||
auto y_start = y_start_roi + static_cast<index_type>(static_cast<T>(y) * roi_height_ratio);
|
auto y_start = y_start_roi + static_cast<index_type>(y * roi_height_ratio);
|
||||||
|
|
||||||
using device::ceil;
|
using device::ceil;
|
||||||
auto x_end = x_start_roi + static_cast<index_type>(ceil(static_cast<T>(x + 1) * roi_width_ratio));
|
auto x_end = x_start_roi + static_cast<index_type>(ceil((x + 1) * roi_width_ratio));
|
||||||
auto y_end = y_start_roi + static_cast<index_type>(ceil(static_cast<T>(y + 1) * roi_height_ratio));
|
auto y_end = y_start_roi + static_cast<index_type>(ceil((y + 1) * roi_height_ratio));
|
||||||
|
|
||||||
using device::max;
|
using device::max;
|
||||||
x_start = max<index_type>(x_start, 0);
|
x_start = max<index_type>(x_start, 0);
|
||||||
@ -76,29 +102,48 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
|||||||
x_end = min<index_type>(x_end, in_width);
|
x_end = min<index_type>(x_end, in_width);
|
||||||
y_end = min<index_type>(y_end, in_height);
|
y_end = min<index_type>(y_end, in_height);
|
||||||
|
|
||||||
/* We have to set the output to zero if (x_start >= x_end) or (y_start >= y_end). If either
|
index_type in_offset = (batch_id * num_channels + c_start) * in_height * in_width;
|
||||||
* condition is true, the loops below won't execute even a single iteration. Hence, by setting
|
index_type out_idx = roi_no * out_roi_size + c_start * out_spatial_size + y * pooled_width + x;
|
||||||
* `max_val` to zero in this case, we can combine it with the `else` code.
|
|
||||||
*/
|
|
||||||
T max_val = (x_start >= x_end || y_start >= y_end) ? T(0) : device::numeric_limits<T>::lowest();
|
|
||||||
|
|
||||||
const index_type in_offset = (batch_id * num_channels + c) * in_height * in_width;
|
for (int i = 0; i < CHANNELS_PER_ITER; i++)
|
||||||
for (auto iy = y_start; iy < y_end; iy++)
|
|
||||||
{
|
{
|
||||||
for (auto ix = x_start; ix < x_end; ix++)
|
/* We have to set the output to zero if (x_start >= x_end) or (y_start >= y_end). If either
|
||||||
{
|
* condition is true, the loops below won't execute even a single iteration. Hence, by setting
|
||||||
const auto in_idx = in_offset + iy * in_width + ix;
|
* `max_val` to zero in this case, we can combine it with the `else` code.
|
||||||
max_val = max(max_val, input[in_idx]);
|
*/
|
||||||
}
|
T max_val = (x_start >= x_end || y_start >= y_end) ? T(0) : device::numeric_limits<T>::lowest();
|
||||||
}
|
|
||||||
|
|
||||||
output[idx] = max_val;
|
for (auto iy = y_start; iy < y_end; iy++)
|
||||||
|
{
|
||||||
|
const auto in_idx = in_offset + iy * in_width;
|
||||||
|
for (auto ix = x_start; ix < x_end; ix++)
|
||||||
|
{
|
||||||
|
max_val = max(max_val, input[in_idx + ix]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output[out_idx] = max_val;
|
||||||
|
|
||||||
|
in_offset += in_image_size;
|
||||||
|
out_idx += out_spatial_size;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <class T, std::size_t CHANNELS_PER_ITER> static
|
||||||
|
void launch_multichannel_roi_pooling(const Stream& stream,
|
||||||
|
Span<T> output, size_type pooled_height, size_type pooled_width,
|
||||||
|
View<T> input, size_type in_height, size_type in_width,
|
||||||
|
View<T> rois, size_type num_channels, float spatial_scale)
|
||||||
|
{
|
||||||
|
auto kernel = raw::roi_pooling<T, CHANNELS_PER_ITER>;
|
||||||
|
auto policy = make_policy(kernel, output.size() / CHANNELS_PER_ITER, 0, stream);
|
||||||
|
launch_kernel(kernel, policy, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale);
|
||||||
|
}
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
void roi_pooling(const Stream& stream, TensorSpan<T> output, TensorView<T> input, View<T> rois, T spatial_scale)
|
void roi_pooling(const Stream& stream, TensorSpan<T> output, TensorView<T> input, View<T> rois, float spatial_scale)
|
||||||
{
|
{
|
||||||
CV_Assert(input.get_axis_size(1) == output.get_axis_size(1));
|
CV_Assert(input.get_axis_size(1) == output.get_axis_size(1));
|
||||||
|
|
||||||
@ -110,13 +155,25 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
|||||||
size_type in_height = input.get_axis_size(2);
|
size_type in_height = input.get_axis_size(2);
|
||||||
size_type in_width = input.get_axis_size(3);
|
size_type in_width = input.get_axis_size(3);
|
||||||
|
|
||||||
auto kernel = raw::roi_pooling<T>;
|
if (num_channels % 64 == 0) {
|
||||||
auto policy = make_policy(kernel, output.size(), 0, stream);
|
launch_multichannel_roi_pooling<T, 64>(stream, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale);
|
||||||
launch_kernel(kernel, policy, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale);
|
} else if (num_channels % 32 == 0) {
|
||||||
|
launch_multichannel_roi_pooling<T, 32>(stream, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale);
|
||||||
|
} else if (num_channels % 16 == 0) {
|
||||||
|
launch_multichannel_roi_pooling<T, 16>(stream, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale);
|
||||||
|
} else if (num_channels % 8 == 0) {
|
||||||
|
launch_multichannel_roi_pooling<T, 8>(stream, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale);
|
||||||
|
} else if (num_channels % 4 == 0) {
|
||||||
|
launch_multichannel_roi_pooling<T, 4>(stream, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale);
|
||||||
|
} else if (num_channels % 2 == 0) {
|
||||||
|
launch_multichannel_roi_pooling<T, 2>(stream, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale);
|
||||||
|
} else {
|
||||||
|
launch_multichannel_roi_pooling<T, 1>(stream, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
|
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
|
||||||
template void roi_pooling(const Stream& stream, TensorSpan<__half> output, TensorView<__half> input, View<__half> rois, __half spatial_scale);
|
template void roi_pooling(const Stream& stream, TensorSpan<__half> output, TensorView<__half> input, View<__half> rois, float spatial_scale);
|
||||||
#endif
|
#endif
|
||||||
template void roi_pooling(const Stream& stream, TensorSpan<float> output, TensorView<float> input, View<float> rois, float spatial_scale);
|
template void roi_pooling(const Stream& stream, TensorSpan<float> output, TensorView<float> input, View<float> rois, float spatial_scale);
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
void roi_pooling(const csl::Stream& stream, csl::TensorSpan<T> output, csl::TensorView<T> input, csl::View<T> rois, T spatial_scale);
|
void roi_pooling(const csl::Stream& stream, csl::TensorSpan<T> output, csl::TensorView<T> input, csl::View<T> rois, float spatial_scale);
|
||||||
|
|
||||||
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
|
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
|
||||||
|
|
||||||
|
@ -600,6 +600,11 @@ TEST_P(Test_Caffe_layers, ROIPooling_Accuracy)
|
|||||||
|
|
||||||
double l1 = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 1e-3 : 1e-5;
|
double l1 = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 1e-3 : 1e-5;
|
||||||
double lInf = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 1e-3 : 1e-4;
|
double lInf = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 1e-3 : 1e-4;
|
||||||
|
if (target == DNN_TARGET_CUDA_FP16)
|
||||||
|
{
|
||||||
|
l1 = 2e-4;
|
||||||
|
lInf = 9e-4;
|
||||||
|
}
|
||||||
normAssert(out, ref, "", l1, lInf);
|
normAssert(out, ref, "", l1, lInf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user