mirror of https://github.com/opencv/opencv.git
Merge pull request #17301 from YashasSamaga:cuda4dnn-detection-output
This commit is contained in: commit e421233a1d
modules/dnn/src/cuda/bbox_utils.hpp (new file, 39 lines)
@@ -0,0 +1,39 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#ifndef OPENCV_DNN_SRC_CUDA_BBOX_UTILS_HPP
#define OPENCV_DNN_SRC_CUDA_BBOX_UTILS_HPP

#include "math.hpp"

#include <cuda_runtime.h>

namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {

    struct BoundingBox
    {
        float xmin, ymin, xmax, ymax;
    };

    template <bool NORMALIZED_BBOX>
    __device__ __forceinline__ float compute_bbox_size(BoundingBox bbox)
    {
        float width = bbox.xmax - bbox.xmin;
        float height = bbox.ymax - bbox.ymin;
        if (width < 0 || height < 0)
            return 0.0;

        if (!NORMALIZED_BBOX)
        {
            width += 1;
            height += 1;
        }

        using csl::device::mul_ftz;
        return mul_ftz(width, height);
    }

}}}} /* namespace cv::dnn::cuda4dnn::kernels */

#endif /* OPENCV_DNN_SRC_CUDA_BBOX_UTILS_HPP */
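
The NORMALIZED_BBOX parameter documents the box convention: for normalized boxes the size is simply width * height, while for pixel-coordinate boxes one is added to width and height so that a degenerate box still covers one pixel. The NMS kernels added later in this PR combine this size with an intersection box to obtain an IoU. A minimal host-side sketch of that composition is shown below; the names Box, bbox_size, iou and main are illustrative only, and plain std::min/std::max stand in for the csl::device helpers.

// Illustrative host-side analogue of the device-side IoU computation.
// Not part of the patch; hypothetical standalone code.
#include <algorithm>
#include <cstdio>

struct Box { float xmin, ymin, xmax, ymax; };

template <bool NORMALIZED_BBOX>
float bbox_size(Box b)
{
    float width = b.xmax - b.xmin, height = b.ymax - b.ymin;
    if (width < 0 || height < 0)
        return 0.f;
    if (!NORMALIZED_BBOX) // pixel coordinates: a box spanning [x, x] is one pixel wide
    {
        width += 1;
        height += 1;
    }
    return width * height;
}

template <bool NORMALIZED_BBOX>
float iou(Box a, Box b)
{
    Box inter { std::max(a.xmin, b.xmin), std::max(a.ymin, b.ymin),
                std::min(a.xmax, b.xmax), std::min(a.ymax, b.ymax) };
    float i = bbox_size<NORMALIZED_BBOX>(inter);
    return i / (bbox_size<NORMALIZED_BBOX>(a) + bbox_size<NORMALIZED_BBOX>(b) - i);
}

int main()
{
    Box a {0.1f, 0.1f, 0.5f, 0.5f}, b {0.3f, 0.3f, 0.7f, 0.7f};
    std::printf("IoU = %f\n", iou<true>(a, b));
}

For these sample boxes the intersection area is 0.04 and the union 0.28, so the printed IoU is roughly 0.1429.
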
modules/dnn/src/cuda/block_stride_range.hpp (new file, 71 lines)
@@ -0,0 +1,71 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#ifndef OPENCV_DNN_SRC_CUDA_BLOCK_STRIDE_RANGE_HPP
#define OPENCV_DNN_SRC_CUDA_BLOCK_STRIDE_RANGE_HPP

#include "types.hpp"
#include "index_helpers.hpp"

#include <cuda_runtime.h>

namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device {

    template <int dim, int BLOCK_SIZE = 0, class index_type = device::index_type, class size_type = device::size_type>
    class block_stride_range_generic {
    public:
        __device__ block_stride_range_generic(index_type to_) : from(0), to(to_) { }
        __device__ block_stride_range_generic(index_type from_, index_type to_) : from(from_), to(to_) { }

        class iterator
        {
        public:
            __device__ iterator(index_type pos_) : pos(pos_) {}

            /* these iterators return the index when dereferenced; this allows us to loop
             * through the indices using a range-based for loop
             */
            __device__ index_type operator*() const { return pos; }

            __device__ iterator& operator++() {
                const index_type block_size = BLOCK_SIZE == 0 ? getBlockDim<dim>() : BLOCK_SIZE;
                pos += block_size;
                return *this;
            }

            __device__ bool operator!=(const iterator& other) const {
                /* NOTE HACK
                 * 'pos' can move in large steps (see operator++)
                 * expansion of the range-based for loop uses != as the loop condition
                 * => operator!= must return false if 'pos' crosses the end
                 */
                return pos < other.pos;
            }

        private:
            index_type pos;
        };

        __device__ iterator begin() const {
            return iterator(from + getThreadIdx<dim>());
        }

        __device__ iterator end() const {
            return iterator(to);
        }

    private:
        index_type from, to;
    };

    using block_stride_range_x = block_stride_range_generic<0>;
    using block_stride_range_y = block_stride_range_generic<1>;
    using block_stride_range_z = block_stride_range_generic<2>;

    template <size_type BLOCK_SIZE = 0>
    using block_stride_range = block_stride_range_generic<0, BLOCK_SIZE>;

}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */

#endif /* OPENCV_DNN_SRC_CUDA_BLOCK_STRIDE_RANGE_HPP */
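
A range-based for loop over block_stride_range<>(n) hands each thread of a block the indices tid, tid + blockDim, tid + 2*blockDim, and so on up to n. The standalone CUDA sketch below shows the raw loop this expands to, assuming getThreadIdx<0>/getBlockDim<0> map to threadIdx.x/blockDim.x as the names suggest; scale_kernel and its launch configuration are purely illustrative and not part of the patch.

// Standalone illustration of the block-stride pattern the range abstracts.
#include <cuda_runtime.h>

__global__ void scale_kernel(float* data, int n, float factor)
{
    // equivalent, in effect, to: for (auto i : block_stride_range<>(n)) data[i] *= factor;
    for (int i = threadIdx.x; i < n; i += blockDim.x)
        data[i] *= factor;
}

int main()
{
    const int n = 1 << 16;
    float* d = nullptr;
    cudaMalloc((void**)&d, n * sizeof(float));
    scale_kernel<<<1, 256>>>(d, n, 2.0f); // one block; each thread visits n/256 indices
    cudaDeviceSynchronize();
    cudaFree(d);
}

Unlike grid_stride_range, the stride here is the block size, so the loop partitions work among the threads of a single block only; that is exactly what the per-class kernels in detection_output.cu rely on.
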
modules/dnn/src/cuda/detection_output.cu (new file, 893 lines)
@@ -0,0 +1,893 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include "math.hpp"
#include "bbox_utils.hpp"
#include "grid_stride_range.hpp"
#include "block_stride_range.hpp"
#include "execution.hpp"
#include "vector_traits.hpp"
#include "memory.hpp"

#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"
#include "../cuda4dnn/csl/tensor.hpp"

using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;

namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {

namespace raw {
|
||||
|
||||
template <class T, bool SHARE_LOCATION, bool VARIANCE_ENCODED_IN_TARGET, bool CORNER_TRUE_CENTER_FALSE, bool CLIP_BBOX>
|
||||
__global__ void decode_bbox(Span<T> decoded_bboxes, View<T> locations, View<T> priors,
|
||||
bool transpose_location, bool normalized_bbox,
|
||||
size_type num_loc_classes, index_type background_class_id,
|
||||
float clip_width, float clip_height)
|
||||
{
|
||||
// decoded_bboxes: [batch_size, num_priors, num_loc_classes, 4]
|
||||
// locations: [batch_size, num_priors, num_loc_classes, 4]
|
||||
// priors: [1, C, num_priors, 4]
|
||||
// C = 2 if !VARIANCE_ENCODED_IN_TARGET; otherwise, 1
|
||||
|
||||
/* 4 bbox values + 4 variance values per prior */
|
||||
constexpr int PRIOR_BOX_SIZE = VARIANCE_ENCODED_IN_TARGET ? 4 : 8;
|
||||
const size_type num_priors = priors.size() / PRIOR_BOX_SIZE;
|
||||
|
||||
using vector_type = get_vector_type_t<T, 4>;
|
||||
auto locations_vPtr = vector_type::get_pointer(locations.data());
|
||||
auto priors_vPtr = vector_type::get_pointer(priors.data());
|
||||
auto decoded_bboxes_vPtr = vector_type::get_pointer(decoded_bboxes.data());
|
||||
|
||||
const auto boxes_per_batch = num_priors * num_loc_classes;
|
||||
for (auto idx : grid_stride_range(decoded_bboxes.size() / 4))
|
||||
{
|
||||
index_type p;
|
||||
index_type c;
|
||||
|
||||
if (SHARE_LOCATION)
|
||||
{
|
||||
// locations are shared across all classes => num_loc_classes = 1
|
||||
p = idx % boxes_per_batch;
|
||||
c = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
p = (idx % boxes_per_batch) / num_loc_classes;
|
||||
c = idx % num_loc_classes;
|
||||
}
|
||||
|
||||
if (!SHARE_LOCATION && c == background_class_id)
|
||||
continue;
|
||||
|
||||
BoundingBox bbox;
|
||||
{
|
||||
vector_type location;
|
||||
v_load(location, locations_vPtr[idx]);
|
||||
|
||||
if (transpose_location)
|
||||
{
|
||||
bbox.ymin = location.data[0];
|
||||
bbox.xmin = location.data[1];
|
||||
bbox.ymax = location.data[2];
|
||||
bbox.xmax = location.data[3];
|
||||
}
|
||||
else
|
||||
{
|
||||
bbox.xmin = location.data[0];
|
||||
bbox.ymin = location.data[1];
|
||||
bbox.xmax = location.data[2];
|
||||
bbox.ymax = location.data[3];
|
||||
}
|
||||
}
|
||||
|
||||
if (!VARIANCE_ENCODED_IN_TARGET)
|
||||
{
|
||||
vector_type prior_variance;
|
||||
v_load_ldg(prior_variance, priors_vPtr[num_priors + p]);
|
||||
|
||||
bbox.xmin *= static_cast<float>(prior_variance.data[0]);
|
||||
bbox.ymin *= static_cast<float>(prior_variance.data[1]);
|
||||
bbox.xmax *= static_cast<float>(prior_variance.data[2]);
|
||||
bbox.ymax *= static_cast<float>(prior_variance.data[3]);
|
||||
}
|
||||
|
||||
BoundingBox prior;
|
||||
{
|
||||
vector_type prior_box;
|
||||
v_load_ldg(prior_box, priors_vPtr[p]);
|
||||
|
||||
prior.xmin = prior_box.data[0];
|
||||
prior.ymin = prior_box.data[1];
|
||||
prior.xmax = prior_box.data[2];
|
||||
prior.ymax = prior_box.data[3];
|
||||
}
|
||||
|
||||
BoundingBox decoded_bbox;
|
||||
if (CORNER_TRUE_CENTER_FALSE)
|
||||
{
|
||||
decoded_bbox.xmin = prior.xmin + bbox.xmin;
|
||||
decoded_bbox.ymin = prior.ymin + bbox.ymin;
|
||||
decoded_bbox.xmax = prior.xmax + bbox.xmax;
|
||||
decoded_bbox.ymax = prior.ymax + bbox.ymax;
|
||||
}
|
||||
else
|
||||
{
|
||||
auto prior_width = prior.xmax - prior.xmin;
|
||||
auto prior_height = prior.ymax - prior.ymin;
|
||||
if (!normalized_bbox)
|
||||
{
|
||||
prior_width += 1;
|
||||
prior_height += 1;
|
||||
}
|
||||
|
||||
auto prior_center_x = prior.xmin + prior_width * 0.5f;
|
||||
auto prior_center_y = prior.ymin + prior_height * 0.5f;
|
||||
|
||||
auto decode_bbox_center_x = bbox.xmin * prior_width + prior_center_x;
|
||||
auto decode_bbox_center_y = bbox.ymin * prior_height + prior_center_y;
|
||||
|
||||
using device::exp;
|
||||
float decode_bbox_width = exp(bbox.xmax) * prior_width;
|
||||
float decode_bbox_height = exp(bbox.ymax) * prior_height;
|
||||
|
||||
decoded_bbox.xmin = decode_bbox_center_x - decode_bbox_width * 0.5f;
|
||||
decoded_bbox.ymin = decode_bbox_center_y - decode_bbox_height * 0.5f;
|
||||
decoded_bbox.xmax = decode_bbox_center_x + decode_bbox_width * 0.5f;
|
||||
decoded_bbox.ymax = decode_bbox_center_y + decode_bbox_height * 0.5f;
|
||||
}
|
||||
|
||||
vector_type decoded_bbox_vec;
|
||||
if (CLIP_BBOX)
|
||||
{
|
||||
decoded_bbox_vec.data[0] = clamp(decoded_bbox.xmin, 0.0f, clip_width);
|
||||
decoded_bbox_vec.data[1] = clamp(decoded_bbox.ymin, 0.0f, clip_height);
|
||||
decoded_bbox_vec.data[2] = clamp(decoded_bbox.xmax, 0.0f, clip_width);
|
||||
decoded_bbox_vec.data[3] = clamp(decoded_bbox.ymax, 0.0f, clip_height);
|
||||
}
|
||||
else
|
||||
{
|
||||
decoded_bbox_vec.data[0] = decoded_bbox.xmin;
|
||||
decoded_bbox_vec.data[1] = decoded_bbox.ymin;
|
||||
decoded_bbox_vec.data[2] = decoded_bbox.xmax;
|
||||
decoded_bbox_vec.data[3] = decoded_bbox.ymax;
|
||||
}
|
||||
|
||||
v_store(decoded_bboxes_vPtr[idx], decoded_bbox_vec);
|
||||
}
|
||||
}
|
||||
|
||||
template <class T, int BINS, int BLOCK_SIZE>
|
||||
__launch_bounds__(BLOCK_SIZE)
|
||||
__global__ void findTopK(Span<int> indices_, Span<int> count_, View<T> scores_, float threshold, size_type classwise_topK, size_type num_classes, size_type num_priors, index_type background_class_id)
|
||||
{
|
||||
/* We need to sort boxes based on their confidence scores. The confidence scores fall in
|
||||
* the range [0.0, 1.0]. We break the range into bins and perform count sort. This is an
|
||||
* approximate algorithm.
|
||||
*
|
||||
* Each block handles a particular class of a particular batch item.
|
||||
*/
|
||||
const auto c = blockIdx.x;
|
||||
const auto b = blockIdx.y;
|
||||
|
||||
if (c == background_class_id)
|
||||
return;
|
||||
|
||||
// indices: [batch_size, num_classes, classwise_topK]
|
||||
// count: [batch_size, num_classes]
|
||||
// scores: [batch_size, num_classes, num_priors]
|
||||
|
||||
auto count = count_.data() + b * num_classes + c;
|
||||
auto scores = scores_.data() + (b * num_classes + c) * num_priors;
|
||||
auto indices = indices_.data() + (b * num_classes + c) * classwise_topK;
|
||||
|
||||
/* We do not require a large number of bins to find the top K confidence scores. We will use
|
||||
* a reasonable number of bins which will fit in the shared memory.
|
||||
*
|
||||
* Note that smaller scores will have a smaller index, i.e. the `bins` are ordered in
|
||||
* ascending order.
|
||||
*/
|
||||
|
||||
__shared__ int bins[BINS];
|
||||
|
||||
#pragma unroll
|
||||
for (int unroll = 0; unroll < BINS / BLOCK_SIZE; unroll++)
|
||||
bins[unroll * BLOCK_SIZE + threadIdx.x] = 0;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (auto i : block_stride_range<BLOCK_SIZE>(num_priors))
|
||||
{
|
||||
const float confidence = load_ldg(scores[i]);
|
||||
if (confidence > threshold)
|
||||
{
|
||||
using device::fast_divide_ftz;
|
||||
auto conf_scaled = fast_divide_ftz(confidence - threshold, 1 - threshold);
|
||||
|
||||
using device::clamp;
|
||||
int bin_index = conf_scaled * BINS;
|
||||
|
||||
/* We store counts of confidence scores in the bins. Our ultimate goal is to store the indices
|
||||
* of the `classwise_topK` confidence values in the `indices` array.
|
||||
*
|
||||
* We use a little trick to parallelize the process of filling up the `indices` array.
|
||||
* We want every thread in the block to participate in the process. To do so, we want the
|
||||
* bins array to be shifted by one place to the left. We will be computing the suffix sum
|
||||
* of the bins array later. Details and reasons for doing so will be explained later.
|
||||
*/
|
||||
bin_index = clamp<int>(bin_index, 0, BINS - 1) - 1; // shift left by one
|
||||
|
||||
if (bin_index >= 0)
|
||||
atomicAdd(&bins[bin_index], 1);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
constexpr int WARP_SIZE = 32; /* must be equal to warpSize */
|
||||
// FORWARD_COMPATIBILITY_TAG: WARP_SIZE_DEPENDENT_CODE
|
||||
|
||||
if (threadIdx.x < WARP_SIZE)
|
||||
{
|
||||
/* We can compute suffix sum of an array in groups of N numbers.
|
||||
* Let N be 4 for this example.
|
||||
*
|
||||
* 1) Last 4 numbers
|
||||
* 1 2 3 4 | 5 6 7 8 | 9 10 11 12
|
||||
* group suffix sum: 42 33 23 12
|
||||
*
|
||||
* 2) Middle 4 numbers
|
||||
* 1 2 3 4 | 5 6 7 8 | 9 10 11 12
|
||||
* group suffix sum: | 26 21 15 8 |
|
||||
*
|
||||
* We add `42` (first element in the previous group) to each element to get:
|
||||
*
|
||||
* 1 2 3 4 | 5 6 7 8 | 9 10 11 12
|
||||
* | 68 63 57 50 | 42 33 23 12
|
||||
* 3) First 4 numbers
|
||||
*
|
||||
* 1 2 3 4 | 5 6 7 8 | 9 10 11 12
|
||||
* group suffix sum: 10 9 7 4 |
|
||||
*
|
||||
* We add `68` (first element in the previous group) to each element to get:
|
||||
*
|
||||
* 1 2 3 4 | 5 6 7 8 | 9 10 11 12
|
||||
* group suffix sum: 78 77 75 72 | 68 63 57 50 | 42 33 23 12
|
||||
*
|
||||
* What we are left with now is the suffix sum of the entire array.
|
||||
*
|
||||
* We use the aforementioned logic in the code below but work in groups of `warpSize`.
|
||||
*/
|
||||
|
||||
/* We calculate suffix sums WARP_SIZE elements at a time starting from the right end.
|
||||
* Hence, we will need BINS / WARP_SIZE number of iterations.
|
||||
*
|
||||
* Each iteration uses shuffle instructions to exchange data between threads. Shuffle
|
||||
* instructions cannot be used in warp-divergent code. If the number of bins is a multiple of
|
||||
* the warpSize, all the threads in the warp will participate.
|
||||
*/
|
||||
static_assert(BINS % WARP_SIZE == 0, "number of bins must be a multiple of warp size");
|
||||
|
||||
const int thread_id = threadIdx.x;
|
||||
const int inverse_lane_id = WARP_SIZE - thread_id - 1;
|
||||
|
||||
int previous_group_first_element = 0;
|
||||
for (int iter = BINS / WARP_SIZE - 1; iter >= 0; iter--)
|
||||
{
|
||||
const index_type idx = iter * WARP_SIZE + thread_id;
|
||||
auto value = bins[idx];
|
||||
|
||||
for (int i = 1; i < WARP_SIZE; i *= 2)
|
||||
{
|
||||
auto n = __shfl_down_sync(0xFFFFFFFF, value, i);
|
||||
if (inverse_lane_id >= i)
|
||||
value += n;
|
||||
}
|
||||
|
||||
value += previous_group_first_element;
|
||||
bins[idx] = value;
|
||||
|
||||
previous_group_first_element = __shfl_sync(0xFFFFFFFF, value, 0);
|
||||
}
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
*count = 0;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (auto i : block_stride_range<BLOCK_SIZE>(num_priors))
|
||||
{
|
||||
const float confidence = load_ldg(scores[i]);
|
||||
if (confidence > threshold)
|
||||
{
|
||||
using device::fast_divide_ftz;
|
||||
auto conf_scaled = fast_divide_ftz(confidence - threshold, 1 - threshold);
|
||||
|
||||
int bin_index = conf_scaled * BINS;
|
||||
bin_index = clamp<int>(bin_index, 0, BINS - 1);
|
||||
|
||||
/* This bounding box is eligible to be selected unless it does not fall within
* the `classwise_topK`. If it is selected, we have to compute the location where it needs
* to be stored.
|
||||
*
|
||||
* Suppose we had just 4 bins and say the following were the counts:
|
||||
* BIN0 2
|
||||
* BIN1 1
|
||||
* BIN2 3
|
||||
* BIN3 0 (last bin is always zero as we shift left by one while populating the bins)
|
||||
*
|
||||
* We will try our best to store the boxes in a sorted order in the `indices` array.
|
||||
* This requires that the boxes in later bins (higher confidence scores) must be
|
||||
* stored earlier.
|
||||
*
|
||||
* We compute the suffix sum of the array. This gives us:
|
||||
* BIN0 6
|
||||
* BIN1 4
|
||||
* BIN2 3
|
||||
* BIN3 0
|
||||
*
|
||||
* The bins now give us the location in the `indices` array from which the indices of the
|
||||
* scores corresponding to that bin would be stored. We atomically increment the bin count
|
||||
* every time we store a box corresponding to that bin. Therefore, the value in the bins
|
||||
* gives the index in the `indices` array where the next box corresponding to that bin must
|
||||
* be put.
|
||||
*/
|
||||
|
||||
const index_type idx = atomicAdd(&bins[bin_index], 1);
|
||||
if (idx < classwise_topK)
|
||||
{
|
||||
indices[idx] = i;
|
||||
atomicAdd(&count[0], 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__global__ void box_collect(Span<T> collected_bboxes_, View<T> decoded_bboxes_, View<int> indices_, View<int> count_, bool share_location, size_type num_priors, size_type num_classes, size_type classwise_topK, index_type background_class_id)
|
||||
{
|
||||
const index_type c = blockIdx.x;
|
||||
if (c == background_class_id)
|
||||
return;
|
||||
|
||||
const index_type b = blockIdx.y;
|
||||
|
||||
// collected_bboxes: [batch_size, num_classes, classwise_topK, 4]
|
||||
// decoded_bboxes: [batch_size, num_priors, num_loc_classes, 4]
|
||||
// indices: [batch_size, num_classes, classwise_topK]
|
||||
// count: [batch_size, num_classes]
|
||||
|
||||
const auto num_loc_classes = share_location ? 1 : num_classes;
|
||||
|
||||
auto collected_bboxes = collected_bboxes_.data() + (b * num_classes + c) * classwise_topK * 4;
|
||||
auto decoded_bboxes = decoded_bboxes_.data() + b * num_priors * num_loc_classes * 4;
|
||||
auto indices = indices_.data() + (b * num_classes + c) * classwise_topK;
|
||||
auto count = count_.data() + b * num_classes + c;
|
||||
|
||||
const auto boxes = load_ldg(&count[0]);
|
||||
if (boxes == 0)
|
||||
return;
|
||||
|
||||
using vector_type = get_vector_type_t<T, 4>;
|
||||
auto decoded_bboxes_vPtr = vector_type::get_pointer(decoded_bboxes);
|
||||
auto collected_bboxes_vPtr = vector_type::get_pointer(collected_bboxes);
|
||||
|
||||
for (auto i : block_stride_range<>(boxes))
|
||||
{
|
||||
const auto prior_id = indices[i];
|
||||
const index_type idx = share_location ? prior_id : (prior_id * num_classes + c);
|
||||
|
||||
vector_type box;
|
||||
v_load(box, decoded_bboxes_vPtr[idx]);
|
||||
v_store(collected_bboxes_vPtr[i], box);
|
||||
}
|
||||
}
|
||||
|
||||
template <class T, bool NORMALIZED_BBOX>
|
||||
__global__ void blockwise_class_nms(Span<int> indices_, Span<int> count_, View<T> collected_bboxes_, size_type num_classes, size_type classwise_topK, index_type background_class_id, float nms_threshold)
|
||||
{
|
||||
const index_type b = blockIdx.x / num_classes;
|
||||
const index_type c = blockIdx.x % num_classes;
|
||||
if (c == background_class_id)
|
||||
return;
|
||||
|
||||
// indices: [batch_size, num_classes, classwise_topK]
|
||||
// count: [batch_size, num_classes]
|
||||
// collected_bboxes: [batch_size, num_classes, classwise_topK, 4]
|
||||
|
||||
auto indices = indices_.data() + (b * num_classes + c) * classwise_topK;
|
||||
auto count = count_.data() + b * num_classes + c;
|
||||
auto collected_bboxes = collected_bboxes_.data() + (b * num_classes + c) * classwise_topK * 4;
|
||||
|
||||
const auto boxes = count[0];
|
||||
if (boxes == 0)
|
||||
return;
|
||||
|
||||
using vector_type = get_vector_type_t<T, 4>;
|
||||
auto collected_bboxes_vPtr = vector_type::get_pointer(collected_bboxes);
|
||||
|
||||
for (int i = 0; i < boxes; i++)
|
||||
{
|
||||
auto prior_id = indices[i];
|
||||
if (prior_id != -1)
|
||||
{
|
||||
BoundingBox bbox1;
|
||||
{
|
||||
vector_type box;
|
||||
v_load(box, collected_bboxes_vPtr[i]);
|
||||
|
||||
bbox1.xmin = box.data[0];
|
||||
bbox1.ymin = box.data[1];
|
||||
bbox1.xmax = box.data[2];
|
||||
bbox1.ymax = box.data[3];
|
||||
}
|
||||
|
||||
for (auto j : block_stride_range<>(i + 1, boxes))
|
||||
{
|
||||
prior_id = indices[j];
|
||||
if (prior_id == -1)
|
||||
continue;
|
||||
|
||||
BoundingBox bbox2;
|
||||
{
|
||||
vector_type box;
|
||||
v_load_ldg(box, collected_bboxes_vPtr[j]);
|
||||
|
||||
bbox2.xmin = box.data[0];
|
||||
bbox2.ymin = box.data[1];
|
||||
bbox2.xmax = box.data[2];
|
||||
bbox2.ymax = box.data[3];
|
||||
}
|
||||
|
||||
using device::min;
|
||||
using device::max;
|
||||
|
||||
BoundingBox intersect_bbox;
|
||||
intersect_bbox.xmin = max(bbox1.xmin, bbox2.xmin);
|
||||
intersect_bbox.ymin = max(bbox1.ymin, bbox2.ymin);
|
||||
intersect_bbox.xmax = min(bbox1.xmax, bbox2.xmax);
|
||||
intersect_bbox.ymax = min(bbox1.ymax, bbox2.ymax);
|
||||
|
||||
float intersect_size = compute_bbox_size<NORMALIZED_BBOX>(intersect_bbox);
|
||||
float bbox1_size = compute_bbox_size<NORMALIZED_BBOX>(bbox1);
|
||||
float bbox2_size = compute_bbox_size<NORMALIZED_BBOX>(bbox2);
|
||||
|
||||
using device::fast_divide_ftz;
|
||||
float iou = fast_divide_ftz(intersect_size, bbox1_size + bbox2_size - intersect_size);
|
||||
if (iou > nms_threshold)
|
||||
indices[j] = -1;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
count[0] = 0;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (auto i : block_stride_range<>(boxes))
|
||||
{
|
||||
auto prior_id = indices[i];
|
||||
if(prior_id != -1)
|
||||
{
|
||||
const index_type idx = atomicAdd(&count[0], 1);
|
||||
indices[idx] = prior_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class T, std::size_t BINS, int BLOCK_SIZE>
|
||||
__launch_bounds__(BLOCK_SIZE)
|
||||
__global__ void nms_collect(
|
||||
Span<int> kept_indices, Span<int> kept_count, View<int> indices_, View<int> count, View<T> scores_, float threshold,
|
||||
size_type num_classes, size_type num_priors, size_type classwise_topK, size_type keepTopK, index_type background_class_id)
|
||||
{
|
||||
// sorting algorithm is documented in detail in findTopK kernel comments
|
||||
// no explanations are provided here
|
||||
|
||||
// kept_indices: [batch_size, keepTopK]
|
||||
// kept_count: [batch_size]
|
||||
|
||||
const auto b = blockIdx.x;
|
||||
|
||||
__shared__ int bins[BINS];
|
||||
|
||||
#pragma unroll
|
||||
for (int unroll = 0; unroll < BINS / BLOCK_SIZE; unroll++)
|
||||
bins[unroll * BLOCK_SIZE + threadIdx.x] = 0;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int c = 0; c < num_classes; c++)
|
||||
{
|
||||
if (c == background_class_id)
|
||||
continue;
|
||||
|
||||
// indices: [batch_size, num_classes, classwise_topK]
|
||||
// count: [batch_size, num_classes]
|
||||
// scores: [batch_size, num_classes, num_priors]
|
||||
|
||||
const auto indices = indices_.data() + (b * num_classes + c) * classwise_topK;
|
||||
const auto scores = scores_.data() + (b * num_classes + c) * num_priors;
|
||||
|
||||
auto boxes = count[b * num_classes + c];
|
||||
|
||||
for (auto i : block_stride_range<BLOCK_SIZE>(boxes))
|
||||
{
|
||||
auto prior_id = indices[i];
|
||||
const float confidence = load_ldg(scores[prior_id]);
|
||||
if (confidence > threshold)
|
||||
{
|
||||
using device::fast_divide_ftz;
|
||||
auto conf_scaled = fast_divide_ftz(confidence - threshold, 1 - threshold);
|
||||
|
||||
using device::clamp;
|
||||
int bin_index = conf_scaled * BINS;
|
||||
bin_index = clamp<int>(bin_index, 0, BINS - 1) - 1; // shift left by one
|
||||
|
||||
if (bin_index >= 0)
|
||||
atomicAdd(&bins[bin_index], 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
constexpr int WARP_SIZE = 32; /* must be equal to warpSize */
|
||||
// FORWARD_COMPATIBILITY_TAG: WARP_SIZE_DEPENDENT_CODE
|
||||
|
||||
if (threadIdx.x < WARP_SIZE)
|
||||
{
|
||||
static_assert(BINS % WARP_SIZE == 0, "number of bins must be a multiple of warp size");
|
||||
|
||||
const int thread_id = threadIdx.x;
|
||||
const int inverse_lane_id = WARP_SIZE - thread_id - 1;
|
||||
|
||||
int previous_group_first_element = 0;
|
||||
for (int iter = BINS / WARP_SIZE - 1; iter >= 0; iter--)
|
||||
{
|
||||
const index_type idx = iter * WARP_SIZE + thread_id;
|
||||
auto value = bins[idx];
|
||||
|
||||
for (int i = 1; i < WARP_SIZE; i *= 2)
|
||||
{
|
||||
auto n = __shfl_down_sync(0xFFFFFFFF, value, i);
|
||||
if (inverse_lane_id >= i)
|
||||
value += n;
|
||||
}
|
||||
|
||||
value += previous_group_first_element;
|
||||
bins[idx] = value;
|
||||
|
||||
previous_group_first_element = __shfl_sync(0xFFFFFFFF, value, 0);
|
||||
}
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
kept_count[b] = 0;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int c = 0; c < num_classes; c++)
|
||||
{
|
||||
if (c == background_class_id)
|
||||
continue;
|
||||
|
||||
const auto indices = indices_.data() + (b * num_classes + c) * classwise_topK;
|
||||
const auto scores = scores_.data() + (b * num_classes + c) * num_priors;
|
||||
|
||||
auto boxes = count[b * num_classes + c];
|
||||
|
||||
for (auto i : block_stride_range<BLOCK_SIZE>(boxes))
|
||||
{
|
||||
auto prior_id = indices[i];
|
||||
const float confidence = load_ldg(scores[prior_id]);
|
||||
if (confidence > threshold)
|
||||
{
|
||||
using device::fast_divide_ftz;
|
||||
auto conf_scaled = fast_divide_ftz(confidence - threshold, 1 - threshold);
|
||||
|
||||
using device::clamp;
|
||||
int bin_index = conf_scaled * BINS;
|
||||
bin_index = clamp<int>(bin_index, 0, BINS - 1);
|
||||
|
||||
const index_type idx = atomicAdd(&bins[bin_index], 1);
|
||||
if (idx < keepTopK)
|
||||
{
|
||||
kept_indices[b * keepTopK + idx] = c * num_priors + prior_id;
|
||||
atomicAdd(&kept_count[b], 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__global__ void consolidate_detections(Span<T> output,
|
||||
View<int> kept_indices, View<int> kept_count, View<T> decoded_bboxes, View<T> scores, bool share_location,
|
||||
size_type batch_size, size_type num_classes, size_type num_priors, size_type keepTopK, DevicePtr<int> num_detections)
|
||||
{
|
||||
using vector_type = get_vector_type_t<T, 4>;
|
||||
auto decoded_bboxes_vPtr = vector_type::get_pointer(decoded_bboxes.data());
|
||||
|
||||
// output: [1, 1, batch_size * keepTopK, 7]
|
||||
// kept_indices: [batch_size, keepTopK]
|
||||
// kept_count: [batch_size]
|
||||
// decoded_bboxes: [batch_size, num_priors, num_loc_classes, 4]
|
||||
// scores: [batch_size, num_classes, num_priors]
|
||||
|
||||
for (int b = 0; b < batch_size; b++)
|
||||
{
|
||||
for (auto i : grid_stride_range(kept_count[b]))
|
||||
{
|
||||
auto score_id = kept_indices[b * keepTopK + i];
|
||||
auto c = score_id / num_priors;
|
||||
auto prior_id = score_id % num_priors;
|
||||
|
||||
const auto confidence = scores[b * num_classes * num_priors + score_id];
|
||||
|
||||
index_type bbox_id;
|
||||
if (share_location)
|
||||
{
|
||||
// decoded_bboxes: [batch_size, num_priors, 1, 4]
|
||||
bbox_id = b * num_priors + prior_id;
|
||||
}
|
||||
else
|
||||
{
|
||||
// decoded_bboxes: [batch_size, num_priors, num_classes, 4]
|
||||
bbox_id = (b * num_priors + prior_id) * num_classes + c;
|
||||
}
|
||||
|
||||
vector_type bbox;
|
||||
v_load(bbox, decoded_bboxes_vPtr[bbox_id]);
|
||||
|
||||
auto output_id = atomicAdd(num_detections.get(), 1);
|
||||
output[output_id * 7 + 0] = b;
|
||||
output[output_id * 7 + 1] = c;
|
||||
output[output_id * 7 + 2] = confidence;
|
||||
output[output_id * 7 + 3] = bbox.data[0];
|
||||
output[output_id * 7 + 4] = bbox.data[1];
|
||||
output[output_id * 7 + 5] = bbox.data[2];
|
||||
output[output_id * 7 + 6] = bbox.data[3];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class T, bool SHARE_LOCATION, bool VARIANCE_ENCODED_IN_TARGET, bool CORNER_TRUE_CENTER_FALSE, bool CLIP_BBOX> static
|
||||
void launch_decode_boxes_kernel(const Stream& stream, Span<T> decoded_bboxes, View<T> locations, View<T> priors,
|
||||
bool transpose_location, bool normalized_bbox,
|
||||
size_type num_loc_classes, index_type background_class_id,
|
||||
float clip_width, float clip_height)
|
||||
{
|
||||
auto kernel = raw::decode_bbox<T, SHARE_LOCATION, VARIANCE_ENCODED_IN_TARGET, CORNER_TRUE_CENTER_FALSE, CLIP_BBOX>;
|
||||
auto policy = make_policy(kernel, decoded_bboxes.size() / 4, 0, stream);
|
||||
launch_kernel(kernel, policy, decoded_bboxes, locations, priors, transpose_location, normalized_bbox, num_loc_classes, background_class_id, clip_width, clip_height);
|
||||
}
|
||||
|
||||
template <class T, std::size_t current, class ...Args> static
|
||||
typename std::enable_if<current == 0, void>
|
||||
::type dispatch_decode_bboxes(int selector, Args&& ...args) {
|
||||
if(selector == 0)
|
||||
launch_decode_boxes_kernel<T, 0, 0, 0, 0>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <class T, std::size_t current, class ...Args> static
|
||||
typename std::enable_if<current != 0, void>
|
||||
::type dispatch_decode_bboxes(int selector, Args&& ...args) {
|
||||
if(selector == current)
|
||||
launch_decode_boxes_kernel<T, current & 8, current & 4, current & 2, current & 1>(std::forward<Args>(args)...);
|
||||
else
|
||||
dispatch_decode_bboxes<T, current - 1, Args...>(selector, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void decode_bboxes(const Stream& stream, Span<T> output, View<T> locations, View<T> priors,
|
||||
std::size_t num_loc_classes,
|
||||
bool share_location, std::size_t background_class_id,
|
||||
bool transpose_location, bool variance_encoded_in_target,
|
||||
bool corner_true_or_center_false, bool normalized_bbox,
|
||||
bool clip_box, float clip_width, float clip_height)
|
||||
{
|
||||
/* `config` combines four kernel template options into one number so that a bit of TMP code can
|
||||
* run through all possible combinations and instantiate the correct template
|
||||
*/
|
||||
unsigned int config = (share_location << 3 | variance_encoded_in_target << 2 | corner_true_or_center_false << 1 | clip_box);
|
||||
dispatch_decode_bboxes<T, 15>(config, stream, output, locations, priors, transpose_location, normalized_bbox, num_loc_classes, background_class_id, clip_width, clip_height);
|
||||
}
|
||||
|
||||
template void decode_bboxes(const Stream&, Span<__half>, View<__half>, View<__half>, std::size_t, bool, std::size_t, bool, bool, bool, bool, bool, float, float);
|
||||
template void decode_bboxes(const Stream&, Span<float>, View<float>, View<float>, std::size_t, bool, std::size_t, bool, bool, bool, bool, bool, float, float);
|
||||
|
||||
template <class T>
|
||||
void findTopK(const Stream& stream, TensorSpan<int> indices, TensorSpan<int> count, TensorView<T> scores, std::size_t background_class_id, float threshold)
|
||||
{
|
||||
// indices: [batch_size, num_classes, classwise_topK]
|
||||
// count: [batch_size, num_classes]
|
||||
// scores: [batch_size, num_classes, num_priors]
|
||||
|
||||
const auto batch_size = indices.get_axis_size(0);
|
||||
CV_Assert(count.get_axis_size(0) == batch_size);
|
||||
CV_Assert(scores.get_axis_size(0) == batch_size);
|
||||
|
||||
const auto num_classes = indices.get_axis_size(1);
|
||||
CV_Assert(count.get_axis_size(1) == num_classes);
|
||||
CV_Assert(scores.get_axis_size(1) == num_classes);
|
||||
|
||||
const auto classwise_topK = indices.get_axis_size(2);
|
||||
const auto num_priors = scores.get_axis_size(2);
|
||||
|
||||
/* each block processes one class from each batch */
|
||||
constexpr auto BLOCK_SIZE = 256;
|
||||
|
||||
dim3 grid_size(num_classes, batch_size);
|
||||
dim3 block_size(BLOCK_SIZE);
|
||||
auto policy = execution_policy(grid_size, block_size, stream);
|
||||
|
||||
auto kernel = raw::findTopK<T, 2048, BLOCK_SIZE>;
|
||||
launch_kernel(kernel, policy, indices, count, scores, threshold, classwise_topK, num_classes, num_priors, background_class_id);
|
||||
}
|
||||
|
||||
template void findTopK(const Stream&, TensorSpan<int>, TensorSpan<int>, TensorView<__half>, std::size_t, float);
|
||||
template void findTopK(const Stream&, TensorSpan<int>, TensorSpan<int>, TensorView<float>, std::size_t, float);
|
||||
|
||||
template <class T>
|
||||
void box_collect(const Stream& stream, TensorSpan<T> collected_bboxes, TensorView<T> decoded_bboxes, TensorView<int> indices, TensorView<int> count, bool share_location, std::size_t background_class_id)
|
||||
{
|
||||
// collected_bboxes: [batch_size, num_classes, classwise_topK, 4]
|
||||
// decoded_bboxes: [batch_size, num_priors, num_loc_classes, 4]
|
||||
// indices: [batch_size, num_classes, classwise_topK]
|
||||
// count: [batch_size, num_classes]
|
||||
|
||||
const auto batch_size = collected_bboxes.get_axis_size(0);
|
||||
CV_Assert(decoded_bboxes.get_axis_size(0) == batch_size);
|
||||
CV_Assert(indices.get_axis_size(0) == batch_size);
|
||||
CV_Assert(count.get_axis_size(0) == batch_size);
|
||||
|
||||
const auto num_classes = collected_bboxes.get_axis_size(1);
|
||||
CV_Assert(indices.get_axis_size(1) == num_classes);
|
||||
CV_Assert(count.get_axis_size(1) == num_classes);
|
||||
|
||||
const auto classwise_topK = collected_bboxes.get_axis_size(2);
|
||||
CV_Assert(indices.get_axis_size(2) == classwise_topK);
|
||||
|
||||
const auto num_priors = decoded_bboxes.get_axis_size(1);
|
||||
|
||||
CV_Assert(!share_location || decoded_bboxes.get_axis_size(2) == 1);
|
||||
|
||||
constexpr int BLOCK_SIZE = 256;
|
||||
|
||||
/* each block processes one class from each batch */
|
||||
dim3 grid_size(num_classes, batch_size);
|
||||
dim3 block_size(BLOCK_SIZE);
|
||||
auto policy = execution_policy(grid_size, block_size, stream);
|
||||
|
||||
auto kernel = raw::box_collect<T>;
|
||||
launch_kernel(kernel, policy, collected_bboxes, decoded_bboxes, indices, count, share_location, num_priors, num_classes, classwise_topK, background_class_id);
|
||||
}
|
||||
|
||||
template void box_collect(const Stream&, TensorSpan<float>, TensorView<float>, TensorView<int>, TensorView<int>, bool, std::size_t);
|
||||
template void box_collect(const Stream&, TensorSpan<__half>, TensorView<__half>, TensorView<int>, TensorView<int>, bool, std::size_t);
|
||||
|
||||
template <class T>
|
||||
void blockwise_class_nms(const Stream& stream, TensorSpan<int> indices, TensorSpan<int> count, TensorView<T> collected_bboxes,
|
||||
bool normalized_bbox, std::size_t background_class_id, float nms_threshold)
|
||||
{
|
||||
// indices: [batch_size, num_classes, classwise_topK]
|
||||
// count: [batch_size, num_classes]
|
||||
// collected_bboxes: [batch_size, num_classes, classwise_topK, 4]
|
||||
|
||||
const auto batch_size = indices.get_axis_size(0);
|
||||
CV_Assert(count.get_axis_size(0) == batch_size);
|
||||
CV_Assert(collected_bboxes.get_axis_size(0) == batch_size);
|
||||
|
||||
const auto num_classes = indices.get_axis_size(1);
|
||||
CV_Assert(count.get_axis_size(1) == num_classes);
|
||||
CV_Assert(collected_bboxes.get_axis_size(1) == num_classes);
|
||||
|
||||
const auto classwise_topK = indices.get_axis_size(2);
|
||||
CV_Assert(collected_bboxes.get_axis_size(2) == classwise_topK);
|
||||
|
||||
/* each block processes one class from each batch */
|
||||
auto num_blocks = batch_size * num_classes;
|
||||
auto num_threads = std::max<std::size_t>(std::min<std::size_t>(1024, classwise_topK), 32);
|
||||
|
||||
dim3 grid_size(num_blocks);
|
||||
dim3 block_size(num_threads);
|
||||
auto policy = execution_policy(grid_size, block_size, stream);
|
||||
|
||||
if (normalized_bbox)
|
||||
{
|
||||
auto kernel = raw::blockwise_class_nms<T, true>;
|
||||
launch_kernel(kernel, policy, indices, count, collected_bboxes, num_classes, classwise_topK, background_class_id, nms_threshold);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto kernel = raw::blockwise_class_nms<T, false>;
|
||||
launch_kernel(kernel, policy, indices, count, collected_bboxes, num_classes, classwise_topK, background_class_id, nms_threshold);
|
||||
}
|
||||
}
|
||||
|
||||
template void blockwise_class_nms(const Stream&, TensorSpan<int>, TensorSpan<int>, TensorView<__half>, bool, std::size_t, float);
|
||||
template void blockwise_class_nms(const Stream&, TensorSpan<int>, TensorSpan<int>, TensorView<float>, bool, std::size_t, float);
|
||||
|
||||
template <class T>
|
||||
void nms_collect(const Stream& stream, TensorSpan<int> kept_indices, TensorSpan<int> kept_count,
|
||||
TensorView<int> indices, TensorView<int> count, TensorView<T> scores, float threshold, std::size_t background_class_id)
|
||||
{
|
||||
// kept_indices: [batch_size, keepTopK]
|
||||
// kept_count: [batch_size]
|
||||
|
||||
// indices: [batch_size, num_classes, classwise_topK]
|
||||
// count: [batch_size, num_classes]
|
||||
// scores: [batch_size, num_classes, num_priors]
|
||||
|
||||
auto batch_size = kept_indices.get_axis_size(0);
|
||||
CV_Assert(kept_count.get_axis_size(0) == batch_size);
|
||||
CV_Assert(indices.get_axis_size(0) == batch_size);
|
||||
CV_Assert(count.get_axis_size(0) == batch_size);
|
||||
CV_Assert(scores.get_axis_size(0) == batch_size);
|
||||
|
||||
auto keepTopK = kept_indices.get_axis_size(1);
|
||||
|
||||
auto num_classes = indices.get_axis_size(1);
|
||||
CV_Assert(count.get_axis_size(1) == num_classes);
|
||||
CV_Assert(scores.get_axis_size(1) == num_classes);
|
||||
|
||||
auto classwise_topK = indices.get_axis_size(2);
|
||||
auto num_priors = scores.get_axis_size(2);
|
||||
|
||||
auto num_blocks = batch_size;
|
||||
constexpr int BLOCK_SIZE = 1024;
|
||||
|
||||
dim3 grid_size(num_blocks);
|
||||
dim3 block_size(BLOCK_SIZE);
|
||||
auto policy = execution_policy(grid_size, block_size, stream);
|
||||
|
||||
auto kernel = raw::nms_collect<T, 1024, BLOCK_SIZE>;
|
||||
launch_kernel(kernel, policy, kept_indices, kept_count, indices, count, scores, threshold, num_classes, num_priors, classwise_topK, keepTopK, background_class_id);
|
||||
}
|
||||
|
||||
template void nms_collect(const Stream&, TensorSpan<int>, TensorSpan<int>, TensorView<int>, TensorView<int>, TensorView<__half>, float, std::size_t);
|
||||
template void nms_collect(const Stream&, TensorSpan<int>, TensorSpan<int>, TensorView<int>, TensorView<int>, TensorView<float>, float, std::size_t);
|
||||
|
||||
template <class T>
|
||||
void consolidate_detections(const Stream& stream, TensorSpan<T> output,
|
||||
TensorView<int> kept_indices, TensorView<int> kept_count,
|
||||
TensorView<T> decoded_bboxes, TensorView<T> scores, bool share_location, DevicePtr<int> num_detections)
|
||||
{
|
||||
// output: [1, 1, batch_size * keepTopK, 7]
|
||||
// kept_indices: [batch_size, keepTopK]
|
||||
// kept_count: [batch_size]
|
||||
// decoded_bboxes: [batch_size, num_priors, num_loc_classes, 4]
|
||||
// scores: [batch_size, num_classes, num_priors]
|
||||
|
||||
auto batch_size = kept_indices.get_axis_size(0);
|
||||
CV_Assert(kept_count.get_axis_size(0) == batch_size);
|
||||
CV_Assert(decoded_bboxes.get_axis_size(0) == batch_size);
|
||||
CV_Assert(scores.get_axis_size(0) == batch_size);
|
||||
|
||||
auto keepTopK = kept_indices.get_axis_size(1);
|
||||
|
||||
auto num_classes = scores.get_axis_size(1);
|
||||
auto num_priors = scores.get_axis_size(2);
|
||||
|
||||
CV_Assert(batch_size * keepTopK * 7 == output.size());
|
||||
|
||||
auto kernel = raw::consolidate_detections<T>;
|
||||
auto policy = make_policy(kernel, keepTopK, 0, stream);
|
||||
launch_kernel(kernel, policy, output, kept_indices, kept_count, decoded_bboxes, scores, share_location, batch_size, num_classes, num_priors, keepTopK, num_detections);
|
||||
}
|
||||
|
||||
template void consolidate_detections(const Stream&, TensorSpan<__half>, TensorView<int>, TensorView<int>, TensorView<__half>, TensorView<__half>, bool, DevicePtr<int>);
|
||||
template void consolidate_detections(const Stream&, TensorSpan<float>, TensorView<int>, TensorView<int>, TensorView<float>, TensorView<float>, bool, DevicePtr<int>);
|
||||
|
||||
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
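
The findTopK and nms_collect kernels in this file select the highest-scoring boxes with an approximate count sort: confidences above the threshold are hashed into bins, a suffix sum over the bins turns counts into output offsets (higher-confidence bins come first), and a second pass scatters box indices using atomic increments on the bins. A sequential C++ sketch of the same idea follows; approx_topk is a hypothetical standalone function and deliberately omits the parallel details (shared-memory bins, warp-shuffle suffix sum, the shift-left-by-one trick).

// Sequential reference for the bin-based approximate top-K used by findTopK.
// Hypothetical standalone code; the real kernels perform the same steps in parallel.
#include <algorithm>
#include <cstdio>
#include <vector>

std::vector<int> approx_topk(const std::vector<float>& scores, float threshold, int K, int BINS = 16)
{
    std::vector<int> bins(BINS, 0);

    // Pass 1: histogram of scores above the threshold.
    for (float s : scores)
        if (s > threshold)
        {
            int b = std::min(int((s - threshold) / (1 - threshold) * BINS), BINS - 1);
            bins[b]++;
        }

    // Suffix sum: bins[b] becomes the output offset of bin b
    // (higher-confidence bins map to earlier output positions).
    int total = 0;
    for (int b = BINS - 1; b >= 0; b--)
    {
        int c = bins[b];
        bins[b] = total;
        total += c;
    }

    // Pass 2: scatter indices; keep only the first K slots.
    std::vector<int> indices(std::min(total, K));
    for (int i = 0; i < (int)scores.size(); i++)
        if (scores[i] > threshold)
        {
            int b = std::min(int((scores[i] - threshold) / (1 - threshold) * BINS), BINS - 1);
            int pos = bins[b]++;
            if (pos < K)
                indices[pos] = i;
        }
    return indices; // approximately sorted by descending confidence
}

int main()
{
    std::vector<float> scores { 0.2f, 0.9f, 0.55f, 0.7f, 0.05f, 0.65f };
    for (int idx : approx_topk(scores, 0.1f, 3))
        std::printf("%d (%.2f)\n", idx, scores[idx]); // prints indices 1, 3, 5
}

The result is only approximately sorted: boxes falling into the same bin keep their original order, which is the accuracy/speed trade-off the kernel comments describe.
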
@@ -67,6 +67,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
template void fill(const Stream&, Span<__half>, __half);
#endif
template void fill(const Stream&, Span<float>, float);
template void fill(const Stream&, Span<int>, int);

template <class T, std::size_t N> static
void launch_vectorized_copy(const Stream& stream, Span<T> output, View<T> input) {
@@ -11,7 +11,7 @@

#include "../cuda4dnn/csl/nvcc_defs.hpp"

namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {

template <class T>
struct IdentityFunctor {
modules/dnn/src/cuda/grid_nms.cu (new file, 467 lines)
@@ -0,0 +1,467 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include "math.hpp"
#include "bbox_utils.hpp"
#include "grid_stride_range.hpp"
#include "block_stride_range.hpp"
#include "execution.hpp"
#include "vector_traits.hpp"
#include "memory.hpp"

#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"
#include "../cuda4dnn/csl/tensor.hpp"

using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;

namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {

namespace raw {
|
||||
|
||||
template <class T, bool NORMALIZED_BBOX, int BLOCK_SIZE>
|
||||
__launch_bounds__(BLOCK_SIZE)
|
||||
__global__ void grid_nms(Span<unsigned int> mask_, Span<int> count_, View<T> bboxes_, size_type num_classes, index_type background_class_id, size_type topK, size_type topK_gs, float nms_threshold)
|
||||
{
|
||||
// topK_gs is topK rounded up to a multiple of GROUP_SIZE (see getAlignedTopK below)
|
||||
|
||||
// mask: [batch_size, num_classes, topK_gs, topK_gs / 32]
|
||||
// bboxes: [batch_size, num_classes, topK, 4]
|
||||
// count: [batch_size, num_classes]
|
||||
|
||||
const index_type c = blockIdx.y;
|
||||
const index_type b = blockIdx.z;
|
||||
|
||||
if (c == background_class_id)
|
||||
return;
|
||||
|
||||
auto mask = mask_.data() + (b * num_classes + c) * topK_gs * topK_gs / 32;
|
||||
auto bboxes = bboxes_.data() + (b * num_classes + c) * topK * 4;
|
||||
auto count = count_.data() + b * num_classes + c;
|
||||
|
||||
const auto boxes = *count;
|
||||
if (boxes == 0)
|
||||
return;
|
||||
|
||||
/* We divide the set of boxes into groups containing BLOCK_SIZE boxes */
|
||||
const auto num_groups = (boxes + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
/* We need to calculate IOUs for every pair of boxes. We can generalize and say that
|
||||
* we need to compute IOUs of every group with every other group including itself.
|
||||
*/
|
||||
// Each block processes a pair of groups.
|
||||
const index_type group_i = blockIdx.x % num_groups;
|
||||
const index_type group_j = blockIdx.x / num_groups;
|
||||
|
||||
/* we use __syncthreads() later but note that the following condition will cause all threads
|
||||
* in the block to exit; hence, no thread will execute a divergent __syncthreads()
|
||||
*/
|
||||
if (group_i >= num_groups || group_j >= num_groups)
|
||||
return;
|
||||
|
||||
/* Note that IOU(A, B) = IOU(B, A). Hence, if we compute IOU(GROUP_A, GROUP_B), we do not need
|
||||
* to compute IOU(GROUP_B, GROUP_A). We still have to compute IOU(GROUP_A, GROUP_A) though since
|
||||
* each group has many boxes and we need IOUs amongst boxes within a group.
|
||||
*
|
||||
* We arbitrarily choose a scheme to exit: exit if group_i is greater than group_j. This way we only
|
||||
* compute IOUs between groups once. While nearly half the blocks are wasted, it's ok since they exit
|
||||
* early on and the working blocks are compute heavy.
|
||||
*/
|
||||
if (group_i > group_j)
|
||||
return;
|
||||
|
||||
/* the following variables contain the absolute box number of the first box of their respective groups */
|
||||
const auto group_i_offset = group_i * BLOCK_SIZE;
|
||||
const auto group_j_offset = group_j * BLOCK_SIZE;
|
||||
|
||||
/* MAIN LOOP LOGIC:
|
||||
* We compare a box `i` from group_i with all boxes in group_j in each iteration. The box `j` is fixed
|
||||
* for each thread. The `j` exactly maps to the thread index. Hence, the `j` is a loop invariant. Each
|
||||
* thread of the block computes the overlap between box `i` and its box `j`.
|
||||
*
|
||||
* for (int i = 0; i < BLOCK_SIZE; i++)
|
||||
* {
|
||||
* // i = box 1
|
||||
* // j = threadIdx.x = box 2
|
||||
* }
|
||||
*/
|
||||
|
||||
/* The `j` box is fixed for each thread. All `i` boxes will be required for every thread.
|
||||
* We store the `i` boxes in shared memory to allow global memory coalescing.
|
||||
*/
|
||||
using vector_type = get_vector_type_t<T, 4>;
|
||||
__shared__ vector_type group_i_boxes[BLOCK_SIZE];
|
||||
|
||||
/* We will precompute the sizes of `i` boxes in the code where we load them. The size computation
|
||||
* is distributed across the block. Otherwise, all threads will have to compute the size of the same
|
||||
* box simultaneously in the main loop. The size is computed while the memory subsystem is busy
|
||||
* servicing requests for box coordinates; the compute resources would otherwise be idle in this phase.
|
||||
*/
|
||||
/* we store the size as a float since the size can exceed fp16 limits for unnormalized boxes */
|
||||
__shared__ float group_i_size[BLOCK_SIZE];
|
||||
|
||||
const auto bboxes_vPtr = vector_type::get_pointer(bboxes);
|
||||
|
||||
// load `i` boxes and precompute their sizes
|
||||
{
|
||||
int i = threadIdx.x;
|
||||
if (group_i_offset + i < boxes)
|
||||
{
|
||||
vector_type box;
|
||||
v_load(box, bboxes_vPtr[group_i_offset + i]);
|
||||
v_store(group_i_boxes[i], box);
|
||||
|
||||
BoundingBox bbox;
|
||||
bbox.xmin = box.data[0];
|
||||
bbox.ymin = box.data[1];
|
||||
bbox.xmax = box.data[2];
|
||||
bbox.ymax = box.data[3];
|
||||
|
||||
group_i_size[i] = compute_bbox_size<NORMALIZED_BBOX>(bbox);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
/* We compute overlap between boxes and check if the IOU exceeds the nms threshold.
|
||||
* We store the result (exceeds or is below nms_threshold) in a two-dimensional matrix.
|
||||
* (i, j) is set to one if the overlap between i and j is within the nms threshold.
|
||||
* We pack 32 results into one 32-bit integer. The effective memory layout of the
|
||||
* matrix hence is (BLOCK_SIZE, BLOCK_SIZE / 32).
|
||||
*/
|
||||
__shared__ unsigned int mask_shared[BLOCK_SIZE * BLOCK_SIZE / 32];
|
||||
|
||||
// load box `j` and precompute its size (fixed per thread)
|
||||
BoundingBox bbox_j;
|
||||
float bbox_j_size = 0;
|
||||
if (group_j_offset + threadIdx.x < boxes)
|
||||
{
|
||||
vector_type box;
|
||||
v_load(box, bboxes_vPtr[group_j_offset + threadIdx.x]);
|
||||
|
||||
bbox_j.xmin = box.data[0];
|
||||
bbox_j.ymin = box.data[1];
|
||||
bbox_j.xmax = box.data[2];
|
||||
bbox_j.ymax = box.data[3];
|
||||
|
||||
bbox_j_size = compute_bbox_size<NORMALIZED_BBOX>(bbox_j);
|
||||
}
|
||||
|
||||
/* Each thread computes a predicate which is broadcasted across the warp to obtain a 32-bit mask.
|
||||
* The lane zero thread of each warp saves the mask. We store the offset to the mask array beforehand
|
||||
* to save cycles in the compute-intensive main loop.
|
||||
*/
|
||||
auto mask_offset = threadIdx.x / 32;
|
||||
|
||||
/* The main loop is compute intensive and causes the kernel to be overall compute-bound. Hence,
|
||||
* this loop has been highly tuned. Please profile and verify carefully before making changes.
|
||||
*/
|
||||
/* UNROLL_SIZE is the number of boxes that must be processed per iteration. We manually unroll
|
||||
* the loop since the compiler cannot effectively unroll on its own, presumably due to the presence
|
||||
* of instructions forcing warp synchronization.
|
||||
*/
|
||||
constexpr int UNROLL_SIZE = 4;
|
||||
|
||||
#pragma unroll 8
|
||||
for (int s = 0; s < BLOCK_SIZE; s += UNROLL_SIZE)
|
||||
{
|
||||
bool do_not_reject_j[UNROLL_SIZE];
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < UNROLL_SIZE; k++)
|
||||
{
|
||||
int i = s + k;
|
||||
|
||||
/* The number of boxes need not necessarily be a multiple of BLOCK_SIZE.
|
||||
* However, the shared memory allocated can hold BLOCK_SIZE boxes from
|
||||
* each group. Accessing the undefined regions of shared memory is
|
||||
* a valid memory operation as long as the memory has been allocated.
|
||||
*
|
||||
* The condition below is only required when one of the groups is not
* fully filled with valid boxes. Such situations are relatively rare. It's
|
||||
* more common to see both groups completely filled.
|
||||
*
|
||||
* We comment this condition to improve the performance of the common case.
|
||||
* This leads to a net improvement.
|
||||
*/
|
||||
// if (group_i_offset + i < boxes && group_j_offset + threadIdx.x < boxes)
|
||||
{
|
||||
BoundingBox bbox_i;
|
||||
float bbox_i_size;
|
||||
{
|
||||
vector_type box;
|
||||
v_load(box, group_i_boxes[i]);
|
||||
bbox_i.xmin = box.data[0];
|
||||
bbox_i.ymin = box.data[1];
|
||||
bbox_i.xmax = box.data[2];
|
||||
bbox_i.ymax = box.data[3];
|
||||
|
||||
bbox_i_size = group_i_size[i];
|
||||
}
|
||||
|
||||
using device::min;
|
||||
using device::max;
|
||||
|
||||
BoundingBox intersect_bbox;
|
||||
intersect_bbox.xmin = max(bbox_i.xmin, bbox_j.xmin);
|
||||
intersect_bbox.ymin = max(bbox_i.ymin, bbox_j.ymin);
|
||||
intersect_bbox.xmax = min(bbox_i.xmax, bbox_j.xmax);
|
||||
intersect_bbox.ymax = min(bbox_i.ymax, bbox_j.ymax);
|
||||
|
||||
float intersect_size = compute_bbox_size<NORMALIZED_BBOX>(intersect_bbox);
|
||||
|
||||
using device::fast_divide_ftz;
|
||||
float iou = fast_divide_ftz(intersect_size, bbox_i_size + bbox_j_size - intersect_size);
|
||||
do_not_reject_j[k] = iou <= nms_threshold;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < UNROLL_SIZE; k++)
|
||||
{
|
||||
// FORWARD_COMPATIBILITY_TAG: WARP_SIZE_DEPENDENT_CODE
|
||||
auto predicate = __ballot_sync(0xFFFFFFFF, do_not_reject_j[k]);
|
||||
if (threadIdx.x % 32 == 0)
|
||||
mask_shared[mask_offset] = predicate;
|
||||
|
||||
/* The following operation should logically be inside the previous if branch. Note that `mask_offset`
|
||||
* is only used by lane zero threads. Hence, there is no harm in executing it in other threads as it is
|
||||
* unused there.
|
||||
*
|
||||
* Keeping it inside prevents the compiler from treating it as a constexpr addition to the address in
|
||||
* successive unrolled iterations. A register is used and instructions are emitted to multiply the
|
||||
* addend by four to obtain the byte offset. Pulling it out of the branch makes the compiler do constexpr
|
||||
* addition on the address in successive unrolled iterations.
|
||||
*/
|
||||
mask_offset += BLOCK_SIZE / 32;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
/* The mask data is organized as a two-dimensional bit matrix of size topK_gs * topK_gs.
|
||||
* (i, j) is set to true if the overlap between `i` and `j` is beyond the nms threshold.
|
||||
* We pack 32 results into one 32-bit integer. So the effective memory layout is topK_gs * topK_gs / 32.
|
||||
*/
|
||||
|
||||
/* Each box `i` was compared with BLOCK_SIZE `j` boxes. This amounts to BLOCK_SIZE / 32
|
||||
* 32-bit integers per box `i`.
|
||||
*/
|
||||
using mask_vector_type = get_vector_type_t<unsigned int, BLOCK_SIZE / 32>;
|
||||
|
||||
const int i = threadIdx.x;
|
||||
|
||||
auto mask_shared_vPtr = mask_vector_type::get_pointer(DevicePtr<unsigned>(mask_shared));
|
||||
mask_vector_type temp;
|
||||
v_load(temp, mask_shared_vPtr[i]);
|
||||
for (int i = 0; i < mask_vector_type::size(); i++)
|
||||
temp.data[i] = __brev(temp.data[i]);
|
||||
|
||||
auto mask_vPtr = mask_vector_type::get_pointer(mask);
|
||||
v_store(mask_vPtr[((group_i_offset + i) * topK_gs + group_j_offset) / 32 / mask_vector_type::size()], temp);
|
||||
}
|
||||
|
||||
template <int ITEMS_PER_THREAD, int BLOCK_SIZE>
|
||||
__launch_bounds__(BLOCK_SIZE)
|
||||
__global__ void grid_nms_collect(Span<int> indices_, Span<int> count_, View<unsigned int> mask_, size_type num_classes, index_type background_class_id, size_type topK, size_type topK_gs_by32)
|
||||
{
|
||||
const index_type c = blockIdx.x;
|
||||
if (c == background_class_id)
|
||||
return;
|
||||
|
||||
const index_type b = blockIdx.y;
|
||||
|
||||
// topK_gs is topK rounded up to a multiple of GROUP_SIZE (see getAlignedTopK below)
|
||||
|
||||
// indices: [batch_size, num_classes, topK]
|
||||
// count: [batch_size, num_classes]
|
||||
// mask: [batch_size, num_classes, topK_gs, topK_gs / 32]
|
||||
|
||||
auto indices = indices_.data() + (b * num_classes + c) * topK;
|
||||
auto count = count_.data() + (b * num_classes + c);
|
||||
auto mask = mask_.data() + (b * num_classes + c) * topK_gs_by32 * 32 * topK_gs_by32;
|
||||
|
||||
const auto boxes = *count;
|
||||
if (boxes == 0)
|
||||
return;
|
||||
|
||||
/* We have a fixed number of threads and an arbitrary number of boxes. We use an array of
|
||||
* bits to store which boxes haven't been eliminated and which are still active. We organize
|
||||
* the array of bits into a matrix of bits of the shape (num_rows, BLOCK_SIZE, 32) which
|
||||
* is equivalent to (num_rows, BLOCK_SIZE) where the type is a 32-bit unsigned integer.
|
||||
* `num_rows` is the minimum number of rows required to cover all the boxes.
|
||||
*
|
||||
* Each thread handles a specific column in the matrix. To improve performance, we process
|
||||
* `ITEMS_PER_THREAD` number of elements per thread. This changes the shape to (num_rows,
|
||||
* ROW_WIDTH) where ROW_WIDTH is BLOCK_SIZE * ITEMS_PER_THREAD.
|
||||
*/
|
||||
constexpr int ROW_WIDTH = BLOCK_SIZE * ITEMS_PER_THREAD;
|
||||
|
||||
const index_type num_32b_masks = static_cast<unsigned>(boxes + 31) / 32;
|
||||
const index_type num_rows = static_cast<unsigned>(num_32b_masks + ROW_WIDTH - 1) / ROW_WIDTH;
|
||||
|
||||
extern __shared__ unsigned int active_boxes[]; // the matrix described earlier
|
||||
|
||||
#pragma unroll 1
|
||||
for (auto idx : block_stride_range<BLOCK_SIZE>(num_32b_masks))
|
||||
active_boxes[idx] = (idx == num_32b_masks - 1) ? __brev((1u << (boxes % 32)) - 1) : 0xFFFFFFFF;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
using vector_type = get_vector_type_t<unsigned int, ITEMS_PER_THREAD>;
|
||||
auto mask_vPtr = vector_type::get_pointer(mask);
|
||||
auto shared_vPtr = vector_type::get_pointer(DevicePtr<unsigned>(active_boxes));
|
||||
|
||||
int index_temp;
|
||||
int thread0_count = 0;
|
||||
int thread_id = threadIdx.x;
|
||||
|
||||
for (int step = 0; step < num_32b_masks; step++)
|
||||
{
|
||||
auto current_active = active_boxes[step];
|
||||
while (current_active)
|
||||
{
|
||||
const index_type bit = __clz(current_active);
|
||||
const index_type i = step * 32 + bit;
|
||||
|
||||
const int mask_offset = static_cast<unsigned>(i * topK_gs_by32) / ITEMS_PER_THREAD;
|
||||
|
||||
/* We fetch the index from the memory and store it in a register. We will not use it until
|
||||
* much later. This helps avoid a long scoreboard stall.
|
||||
*/
|
||||
if (thread_id == 0)
|
||||
index_temp = indices[i];
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
active_boxes[step] = current_active ^ (0x80000000 >> bit);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll 1
|
||||
for (int r = 0; r < num_rows; r++)
|
||||
{
|
||||
const int idx = r * BLOCK_SIZE + thread_id;
|
||||
if ((step & ~(ITEMS_PER_THREAD - 1)) <= idx * ITEMS_PER_THREAD && idx * ITEMS_PER_THREAD < num_32b_masks)
|
||||
{
|
auto active_boxes_vec = shared_vPtr[idx];
auto mask_vec = mask_vPtr[mask_offset + idx];
for (int i = 0; i < vector_type::size(); i++)
active_boxes_vec.data[i] &= mask_vec.data[i];
shared_vPtr[idx] = active_boxes_vec;
}
}

__syncthreads();

if (thread_id == 0)
{
indices[thread0_count] = index_temp;
thread0_count++;
}

current_active = active_boxes[step];
}
}

if (threadIdx.x == 0)
*count = thread0_count;
}
}

constexpr int GROUP_SIZE = 128;

static std::size_t getAlignedTopK(std::size_t topK)
{
auto remainder = topK % GROUP_SIZE;
if (remainder == 0)
return topK;
return topK + (GROUP_SIZE - remainder);
}
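/* e.g. classwise_topK = 100 is padded to 128 and classwise_topK = 200 to 256 */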

std::size_t getGridNMSWorkspaceSizePerBatchItem(std::size_t num_classes, std::size_t classwise_topK)
{
auto topK_gs = getAlignedTopK(classwise_topK);
return num_classes * topK_gs * topK_gs / 32 * sizeof(unsigned int);
}
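/* e.g. num_classes = 21 and classwise_topK = 100 give topK_gs = 128 and a per-item workspace of
* 21 * 128 * 128 / 32 * sizeof(unsigned int) = 43008 bytes
*/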

template <class T>
void grid_nms(const Stream& stream, Span<unsigned int> workspace, TensorSpan<int> indices, TensorSpan<int> count, TensorView<T> bboxes, int background_class_id, bool normalized_bbox, float nms_threshold)
{
// workspace: [batch_size, num_classes, topK_gs, topK_gs / 32]
// indices: [batch_size, num_classes, topK]
// count: [batch_size, num_classes]
// bboxes: [batch_size, num_classes, topK, 4] (only first count[b][c] boxes are read)

const auto batch_size = indices.get_axis_size(0);
CV_Assert(count.get_axis_size(0) == batch_size);
CV_Assert(bboxes.get_axis_size(0) == batch_size);

const auto num_classes = indices.get_axis_size(1);
CV_Assert(count.get_axis_size(1) == num_classes);
CV_Assert(bboxes.get_axis_size(1) == num_classes);

const auto topK = indices.get_axis_size(2);
CV_Assert(bboxes.get_axis_size(2) == topK);

CV_Assert(bboxes.get_axis_size(3) == 4);

const auto topK_gs = getAlignedTopK(topK);
CV_Assert(workspace.size() >= topK_gs * topK_gs / 32);

const auto boxes = topK;
const auto num_groups = (boxes + GROUP_SIZE - 1) / GROUP_SIZE;
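/* e.g. topK = 200 gives num_groups = 2, so the mask-building kernel below runs a 2 x 2 = 4 block
* tile per (class, batch item) pair
*/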

{
// grid = (num_groups * num_groups, num_classes, batch_size)
// if the background class is the last class, we can reduce grid y dim by one
auto grid_num_classes = num_classes; //(background_class_id == num_classes - 1) ? num_classes - 1 : num_classes;

constexpr int BLOCK_SIZE = GROUP_SIZE;

dim3 grid_size(num_groups * num_groups, grid_num_classes, batch_size);
dim3 block_size(BLOCK_SIZE);
auto policy = execution_policy(grid_size, block_size, stream);

if (normalized_bbox)
{
auto kernel = raw::grid_nms<T, true, BLOCK_SIZE>;
launch_kernel(kernel, policy, workspace, count, bboxes, num_classes, background_class_id, topK, topK_gs, nms_threshold);
}
else
{
auto kernel = raw::grid_nms<T, false, BLOCK_SIZE>;
launch_kernel(kernel, policy, workspace, count, bboxes, num_classes, background_class_id, topK, topK_gs, nms_threshold);
}
}

{
// grid = (num_classes, batch_size)
// if the background class is the last class, we can reduce grid x dim by one
auto grid_num_classes = num_classes; //(background_class_id == num_classes - 1) ? num_classes - 1 : num_classes;

constexpr int BLOCK_SIZE = 64;

constexpr int ITEMS_PER_THREAD = 4;
auto kernel = raw::grid_nms_collect<ITEMS_PER_THREAD, BLOCK_SIZE>;

dim3 grid_size(grid_num_classes, batch_size);

auto sharedMem = topK_gs / 32 * 4;
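// one 32-bit word per 32 candidates: e.g. topK_gs = 256 needs 256 / 32 * 4 = 32 bytes of shared memory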
auto policy = execution_policy(grid_size, BLOCK_SIZE, sharedMem, stream);
launch_kernel(kernel, policy, indices, count, workspace, num_classes, background_class_id, topK, topK_gs / 32);
}
}

std::size_t getGridNMSWorkspaceSizePerBatchItem(std::size_t num_classes, std::size_t classwise_topK);

template void grid_nms(const Stream& stream, Span<unsigned int> workspace, TensorSpan<int> indices, TensorSpan<int> count, TensorView<__half> bboxes, int, bool normalized_bbox, float nms_threshold);
template void grid_nms(const Stream& stream, Span<unsigned int> workspace, TensorSpan<int> indices, TensorSpan<int> count, TensorView<float> bboxes, int, bool normalized_bbox, float nms_threshold);

}}}} /* namespace cv::dnn::cuda4dnn::kernels */
@ -6,90 +6,62 @@
#define OPENCV_DNN_SRC_CUDA_GRID_STRIDE_RANGE_HPP

#include "types.hpp"
#include "index_helpers.hpp"

#include <cuda_runtime.h>

namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device {

namespace detail {
using dim3_member_type = decltype(dim3::x);

template <int> __device__ dim3_member_type getGridDim();
template <> inline __device__ dim3_member_type getGridDim<0>() { return gridDim.x; }
template <> inline __device__ dim3_member_type getGridDim<1>() { return gridDim.y; }
template <> inline __device__ dim3_member_type getGridDim<2>() { return gridDim.z; }

template <int> __device__ dim3_member_type getBlockDim();
template <> inline __device__ dim3_member_type getBlockDim<0>() { return blockDim.x; }
template <> inline __device__ dim3_member_type getBlockDim<1>() { return blockDim.y; }
template <> inline __device__ dim3_member_type getBlockDim<2>() { return blockDim.z; }

using uint3_member_type = decltype(uint3::x);

template <int> __device__ uint3_member_type getBlockIdx();
template <> inline __device__ uint3_member_type getBlockIdx<0>() { return blockIdx.x; }
template <> inline __device__ uint3_member_type getBlockIdx<1>() { return blockIdx.y; }
template <> inline __device__ uint3_member_type getBlockIdx<2>() { return blockIdx.z; }

template <int> __device__ uint3_member_type getThreadIdx();
template <> inline __device__ uint3_member_type getThreadIdx<0>() { return threadIdx.x; }
template <> inline __device__ uint3_member_type getThreadIdx<1>() { return threadIdx.y; }
template <> inline __device__ uint3_member_type getThreadIdx<2>() { return threadIdx.z; }
}

template <int dim, class index_type = device::index_type, class size_type = device::size_type>
class grid_stride_range_generic {
public:
__device__ grid_stride_range_generic(index_type to_) : from(0), to(to_) { }
__device__ grid_stride_range_generic(index_type from_, index_type to_) : from(from_), to(to_) { }

class iterator
{
public:
__device__ iterator(index_type pos_) : pos(pos_) {}

/* these iterators return the index when dereferenced; this allows us to loop
* through the indices using a range based for loop
*/
__device__ index_type operator*() const { return pos; }

__device__ iterator& operator++() {
pos += detail::getGridDim<dim>() * static_cast<index_type>(detail::getBlockDim<dim>());
return *this;
}
__device__ iterator& operator++() {
pos += getGridDim<dim>() * static_cast<index_type>(getBlockDim<dim>());
return *this;
}

__device__ bool operator!=(const iterator& other) const {
/* NOTE HACK
* 'pos' can move in large steps (see operator++)
* expansion of range for loop uses != as the loop condition
* => operator!= must return false if 'pos' crosses the end
*/
return pos < other.pos;
}

private:
index_type pos;
};

__device__ iterator begin() const {
using detail::getBlockDim;
using detail::getBlockIdx;
using detail::getThreadIdx;
return iterator(from + getBlockDim<dim>() * getBlockIdx<dim>() + getThreadIdx<dim>());
}
__device__ iterator begin() const {
return iterator(from + getBlockDim<dim>() * getBlockIdx<dim>() + getThreadIdx<dim>());
}

__device__ iterator end() const {
return iterator(to);
}

private:
index_type from, to;
};

using grid_stride_range_x = grid_stride_range_generic<0>;
using grid_stride_range_y = grid_stride_range_generic<1>;
using grid_stride_range_z = grid_stride_range_generic<2>;
using grid_stride_range = grid_stride_range_x;

}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */
41
modules/dnn/src/cuda/index_helpers.hpp
Normal file
41
modules/dnn/src/cuda/index_helpers.hpp
Normal file
@ -0,0 +1,41 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#ifndef OPENCV_DNN_SRC_CUDA_INDEX_HELPERS_HPP
#define OPENCV_DNN_SRC_CUDA_INDEX_HELPERS_HPP

#include "types.hpp"

#include <cuda_runtime.h>

namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device {

namespace detail {
using dim3_member_type = decltype(dim3::x);
using uint3_member_type = decltype(uint3::x);
}

template <int> __device__ detail::dim3_member_type getGridDim();
template <> inline __device__ detail::dim3_member_type getGridDim<0>() { return gridDim.x; }
template <> inline __device__ detail::dim3_member_type getGridDim<1>() { return gridDim.y; }
template <> inline __device__ detail::dim3_member_type getGridDim<2>() { return gridDim.z; }

template <int> __device__ detail::dim3_member_type getBlockDim();
template <> inline __device__ detail::dim3_member_type getBlockDim<0>() { return blockDim.x; }
template <> inline __device__ detail::dim3_member_type getBlockDim<1>() { return blockDim.y; }
template <> inline __device__ detail::dim3_member_type getBlockDim<2>() { return blockDim.z; }

template <int> __device__ detail::uint3_member_type getBlockIdx();
template <> inline __device__ detail::uint3_member_type getBlockIdx<0>() { return blockIdx.x; }
template <> inline __device__ detail::uint3_member_type getBlockIdx<1>() { return blockIdx.y; }
template <> inline __device__ detail::uint3_member_type getBlockIdx<2>() { return blockIdx.z; }

template <int> __device__ detail::uint3_member_type getThreadIdx();
template <> inline __device__ detail::uint3_member_type getThreadIdx<0>() { return threadIdx.x; }
template <> inline __device__ detail::uint3_member_type getThreadIdx<1>() { return threadIdx.y; }
template <> inline __device__ detail::uint3_member_type getThreadIdx<2>() { return threadIdx.z; }

}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */

#endif /* OPENCV_DNN_SRC_CUDA_INDEX_HELPERS_HPP */
@ -122,9 +122,23 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace de
template <> inline __device__ __half ceil(__half value) { return hceil(value); }
#endif

template <class T> __device__ T mul_ftz(T x, T y) { return x * y; }
template <> inline __device__ float mul_ftz(float x, float y) {
float result;
asm("mul.ftz.f32 %0, %1, %2;" : "=f"(result) : "f"(x), "f"(y));
return result;
}

template <class T> __device__ T fast_divide(T x, T y) { return x / y; }
template <> inline __device__ float fast_divide(float x, float y) { return __fdividef(x, y); }

template <class T> __device__ T fast_divide_ftz(T x, T y) { return fast_divide(x, y); }
template <> inline __device__ float fast_divide_ftz(float x, float y) {
float result;
asm("div.approx.ftz.f32 %0, %1, %2;" : "=f"(result) : "f"(x), "f"(y));
return result;
}

template <class T> __device__ T fast_exp(T value) { return exp(value); }
template <> inline __device__ float fast_exp(float value) { return __expf(value); }
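// __fdividef and __expf are CUDA fast-math intrinsics; the .approx/.ftz PTX forms trade exact rounding for throughput and flush denormals to zero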
@ -780,7 +780,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
}

template <class ForwardItr>
TensorView(pointer ptr_, ForwardItr start, ForwardItr end) : ptr{ ptr_ } {
TensorView(const_pointer ptr_, ForwardItr start, ForwardItr end) : ptr{ ptr_ } {
CV_Assert(start != end);
CV_Assert(std::distance(start, end) <= CSL_MAX_TENSOR_RANK);
42
modules/dnn/src/cuda4dnn/kernels/detection_output.hpp
Normal file
42
modules/dnn/src/cuda4dnn/kernels/detection_output.hpp
Normal file
@ -0,0 +1,42 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_DETECTION_OUTPUT_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_DETECTION_OUTPUT_HPP

#include "../csl/stream.hpp"
#include "../csl/span.hpp"
#include "../csl/tensor.hpp"

namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {

template <class T>
void decode_bboxes(const csl::Stream& stream, csl::Span<T> output, csl::View<T> locations, csl::View<T> priors,
std::size_t num_loc_classes, bool share_location, std::size_t background_label_id,
bool transpose_location, bool variance_encoded_in_target,
bool corner_true_or_center_false, bool normalized_bbox,
bool clip_box, float clip_width, float clip_height);

template <class T>
void findTopK(const csl::Stream& stream, csl::TensorSpan<int> indices, csl::TensorSpan<int> count, csl::TensorView<T> scores, std::size_t background_label_id, float threshold);

template <class T>
void box_collect(const csl::Stream& stream, csl::TensorSpan<T> collected_bboxes, csl::TensorView<T> decoded_bboxes, csl::TensorView<int> indices, csl::TensorView<int> count, bool share_location, std::size_t background_label_id);

template <class T>
void blockwise_class_nms(const csl::Stream& stream, csl::TensorSpan<int> indices, csl::TensorSpan<int> count, csl::TensorView<T> collected_bboxes,
bool normalized_bbox, std::size_t background_label_id, float nms_threshold);

template <class T>
void nms_collect(const csl::Stream& stream, csl::TensorSpan<int> kept_indices, csl::TensorSpan<int> kept_count,
csl::TensorView<int> indices, csl::TensorView<int> count, csl::TensorView<T> scores, float, std::size_t background_label_id);

template <class T>
void consolidate_detections(const csl::Stream& stream, csl::TensorSpan<T> output,
csl::TensorView<int> kept_indices, csl::TensorView<int> kept_count,
csl::TensorView<T> decoded_bboxes, csl::TensorView<T> scores, bool share_location, csl::DevicePtr<int> num_detections);

}}}} /* namespace cv::dnn::cuda4dnn::kernels */

#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_DETECTION_OUTPUT_HPP */
21
modules/dnn/src/cuda4dnn/kernels/grid_nms.hpp
Normal file
21
modules/dnn/src/cuda4dnn/kernels/grid_nms.hpp
Normal file
@ -0,0 +1,21 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_GRID_NMS_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_GRID_NMS_HPP

#include "../csl/stream.hpp"
#include "../csl/span.hpp"
#include "../csl/tensor.hpp"

namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {

std::size_t getGridNMSWorkspaceSizePerBatchItem(std::size_t num_classes, std::size_t classwise_topK);

template <class T>
void grid_nms(const csl::Stream& stream, csl::Span<unsigned int> workspace, csl::TensorSpan<int> indices, csl::TensorSpan<int> count, csl::TensorView<T> bboxes, int background_class_id, bool normalized_bbox, float nms_threshold);

}}}} /* namespace cv::dnn::cuda4dnn::kernels */

#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_GRID_NMS_HPP */
282
modules/dnn/src/cuda4dnn/primitives/detection_output.hpp
Normal file
282
modules/dnn/src/cuda4dnn/primitives/detection_output.hpp
Normal file
@ -0,0 +1,282 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_DETECTION_OUTPUT_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_DETECTION_OUTPUT_HPP

#include "../../op_cuda.hpp"

#include "../csl/stream.hpp"
#include "../csl/tensor.hpp"

#include "../kernels/fill_copy.hpp"
#include "../kernels/permute.hpp"
#include "../kernels/detection_output.hpp"
#include "../kernels/grid_nms.hpp"

#include <cstddef>
#include <utility>

namespace cv { namespace dnn { namespace cuda4dnn {

struct DetectionOutputConfiguration {
std::size_t batch_size;

enum class CodeType {
CORNER,
CENTER_SIZE
};
CodeType code_type;

bool share_location;
std::size_t num_priors;
std::size_t num_classes;
std::size_t background_class_id;

bool transpose_location;
bool variance_encoded_in_target;
bool normalized_bbox;
bool clip_box;

std::size_t classwise_topK;
float confidence_threshold;
float nms_threshold;

int keepTopK;
};

template <class T>
class DetectionOutputOp final : public CUDABackendNode {
private:
/* We have a block level NMS kernel where each block handles one class of one batch item.
* If the number of classes and batch size together is very low, the blockwise NMS kernel
* won't be able to fully saturate the GPU with work.
*
* We also have a grid level NMS kernel where multiple blocks handle each class of every batch item.
* This performs better in the worst case and utilizes resources better when the block level kernel isn't
* able to saturate the GPU with enough work. However, this is not efficient in the average case where
* the block level kernel is able to saturate the GPU. It does better when the blockwise NMS barely
* saturates the GPU.
*
* `GRID_NMS_CUTOFF` is the cutoff for `num_classes * batch_size` above which we will switch from grid
* level NMS to block level NMS.
*/
static constexpr int GRID_NMS_CUTOFF = 32;
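/* e.g. a single-image SSD run with 21 classes gives num_classes * batch_size = 21, which stays
* under the cutoff and selects the grid level NMS path in forward(); a batch of four images
* (84) selects the block level path instead
*/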

public:
using wrapper_type = GetCUDABackendWrapperType<T>;

DetectionOutputOp(csl::Stream stream_, const DetectionOutputConfiguration& config)
: stream(std::move(stream_))
{
corner_true_or_center_false = (config.code_type == DetectionOutputConfiguration::CodeType::CORNER);

share_location = config.share_location;
num_priors = config.num_priors;
num_classes = config.num_classes;
background_class_id = config.background_class_id;

transpose_location = config.transpose_location;
variance_encoded_in_target = config.variance_encoded_in_target;
normalized_bbox = config.normalized_bbox;
clip_box = config.clip_box;

classwise_topK = config.classwise_topK;
confidence_threshold = config.confidence_threshold;
nms_threshold = config.nms_threshold;

keepTopK = config.keepTopK;
CV_Assert(keepTopK > 0);

if (classwise_topK == -1)
{
classwise_topK = num_priors;
if (keepTopK > 0 && keepTopK < num_priors)
classwise_topK = keepTopK;
}
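/* e.g. with num_priors = 8732, keepTopK = 200 and classwise_topK passed as -1, the per-class
* candidate list is capped at 200 rather than 8732
*/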

auto batch_size = config.batch_size;
auto num_loc_classes = (share_location ? 1 : num_classes);

csl::WorkspaceBuilder builder;
builder.require<T>(batch_size * num_priors * num_loc_classes * 4); /* decoded boxes */
builder.require<T>(batch_size * num_classes * num_priors); /* transposed scores */
builder.require<int>(batch_size * num_classes * classwise_topK); /* indices */
builder.require<int>(batch_size * num_classes); /* classwise topK count */
builder.require<T>(batch_size * num_classes * classwise_topK * 4); /* topK decoded boxes */

if (batch_size * num_classes <= GRID_NMS_CUTOFF)
{
auto workspace_per_batch_item = kernels::getGridNMSWorkspaceSizePerBatchItem(num_classes, classwise_topK);
builder.require(batch_size * workspace_per_batch_item);
}

builder.require<int>(batch_size * keepTopK); /* final kept indices */
builder.require<int>(batch_size); /* kept indices count */
builder.require<int>(1); /* total number of detections */
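/* the buffers are requested in the same order in which forward() later draws them from the
* workspace allocator: decoded boxes, permuted scores, per-class indices and counts, collected
* boxes, the optional grid NMS masks, then the kept indices, kept counts and the detection counter
*/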

scratch_mem_in_bytes = builder.required_workspace_size();
}

void forward(
const std::vector<cv::Ptr<BackendWrapper>>& inputs,
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
csl::Workspace& workspace) override
{
/* locations, scores and priors make the first three inputs in order */
/* the 4th input is used to obtain the shape for clipping */
CV_Assert((inputs.size() == 3 || inputs.size() == 4) && outputs.size() == 1);

// locations: [batch_size, num_priors, num_loc_classes, 4]
auto locations_wrapper = inputs[0].dynamicCast<wrapper_type>();
auto locations = locations_wrapper->getView();

// scores: [batch_size, num_priors, num_classes]
auto scores_wrapper = inputs[1].dynamicCast<wrapper_type>();
auto scores = scores_wrapper->getView();
scores.unsqueeze();
scores.reshape(-1, num_priors, num_classes);
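// the scores blob may arrive with fewer axes than expected; unsqueeze/reshape normalizes it to [batch_size, num_priors, num_classes]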

// priors: [1, 2, num_priors, 4]
auto priors_wrapper = inputs[2].dynamicCast<wrapper_type>();
auto priors = priors_wrapper->getView();

// output: [1, 1, batch_size * keepTopK, 7]
auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
auto output = output_wrapper->getSpan();

auto batch_size = locations.get_axis_size(0);
auto num_loc_classes = (share_location ? 1 : num_classes);
while(locations.rank() < 4)
locations.unsqueeze();
locations.reshape(batch_size, num_priors, num_loc_classes, 4);

float clip_width = 0.0, clip_height = 0.0;
if (clip_box)
{
if (normalized_bbox)
{
clip_width = clip_height = 1.0f;
}
else
{
auto image_wrapper = inputs[3].dynamicCast<wrapper_type>();
auto image_shape = image_wrapper->getShape();

CV_Assert(image_shape.size() == 4);
clip_width = image_shape[3] - 1;
clip_height = image_shape[2] - 1;
}
}

csl::WorkspaceAllocator allocator(workspace);

// decoded_boxes: [batch_size, num_priors, num_loc_classes, 4]
csl::TensorSpan<T> decoded_boxes;
{
auto shape = std::vector<std::size_t>{batch_size, num_priors, num_loc_classes, 4};
decoded_boxes = allocator.get_tensor_span<T>(std::begin(shape), std::end(shape));
CV_Assert(is_shape_same(decoded_boxes, locations));
}

kernels::decode_bboxes<T>(stream, decoded_boxes, locations, priors,
num_loc_classes, share_location, background_class_id,
transpose_location, variance_encoded_in_target,
corner_true_or_center_false, normalized_bbox,
clip_box, clip_width, clip_height);

// scores_permuted: [batch_size, num_classes, num_priors]
csl::TensorSpan<T> scores_permuted;
{
auto shape = std::vector<std::size_t>{batch_size, num_classes, num_priors};
scores_permuted = allocator.get_tensor_span<T>(std::begin(shape), std::end(shape));
}

kernels::permute<T>(stream, scores_permuted, scores, {0, 2, 1});

// indices: [batch_size, num_classes, classwise_topK]
csl::TensorSpan<int> indices;
{
auto shape = std::vector<std::size_t>{batch_size, num_classes, classwise_topK};
indices = allocator.get_tensor_span<int>(std::begin(shape), std::end(shape));
}

// count: [batch_size, num_classes]
csl::TensorSpan<int> count;
{
auto shape = std::vector<std::size_t>{batch_size, num_classes};
count = allocator.get_tensor_span<int>(std::begin(shape), std::end(shape));
}

kernels::findTopK<T>(stream, indices, count, scores_permuted, background_class_id, confidence_threshold);

// collected_bboxes: [batch_size, num_classes, classwise_topK, 4]
csl::TensorSpan<T> collected_bboxes;
{
auto shape = std::vector<std::size_t>{batch_size, num_classes, classwise_topK, 4};
collected_bboxes = allocator.get_tensor_span<T>(std::begin(shape), std::end(shape));
}

kernels::box_collect<T>(stream, collected_bboxes, decoded_boxes, indices, count, share_location, background_class_id);

if (batch_size * num_classes <= GRID_NMS_CUTOFF)
{
auto workspace_per_batch_item = kernels::getGridNMSWorkspaceSizePerBatchItem(num_classes, classwise_topK);
auto workspace = allocator.get_span<unsigned int>(batch_size * workspace_per_batch_item / sizeof(unsigned int));
kernels::grid_nms<T>(stream, workspace, indices, count, collected_bboxes, background_class_id, normalized_bbox, nms_threshold);
}
else
{
kernels::blockwise_class_nms<T>(stream, indices, count, collected_bboxes, normalized_bbox, background_class_id, nms_threshold);
}

// kept_indices: [batch_size, keepTopK]
csl::TensorSpan<int> kept_indices;
{
auto shape = std::vector<std::size_t>{batch_size, static_cast<std::size_t>(keepTopK)};
kept_indices = allocator.get_tensor_span<int>(std::begin(shape), std::end(shape));
}

// kept_count: [batch_size]
csl::TensorSpan<int> kept_count;
{
auto shape = std::vector<std::size_t>{batch_size};
kept_count = allocator.get_tensor_span<int>(std::begin(shape), std::end(shape));
}

kernels::nms_collect<T>(stream, kept_indices, kept_count, indices, count, scores_permuted, confidence_threshold, background_class_id);

auto num_detections = allocator.get_span<int>(1);
kernels::fill<int>(stream, num_detections, 0);
kernels::fill<T>(stream, output, 0.0);
kernels::consolidate_detections<T>(stream, output, kept_indices, kept_count, decoded_boxes, scores_permuted, share_location, num_detections.data());
}

std::size_t get_workspace_memory_in_bytes() const noexcept override { return scratch_mem_in_bytes; }

private:
csl::Stream stream;
std::size_t scratch_mem_in_bytes;

bool share_location;
std::size_t num_priors;
std::size_t num_classes;
std::size_t background_class_id;

bool transpose_location;
bool variance_encoded_in_target;
bool corner_true_or_center_false;
bool normalized_bbox;
bool clip_box;

std::size_t classwise_topK;
float confidence_threshold;
float nms_threshold;

int keepTopK;
};

}}} /* namespace cv::dnn::cuda4dnn */

#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_DETECTION_OUTPUT_HPP */
@ -42,6 +42,7 @@

#include "../precomp.hpp"
#include "layers_common.hpp"
#include "../op_cuda.hpp"
#include "../op_inf_engine.hpp"

#include <float.h>
@ -59,7 +60,11 @@
#else
#include <ngraph/op/experimental/layers/detection_output.hpp>
#endif
#endif

#ifdef HAVE_CUDA
#include "../cuda4dnn/primitives/detection_output.hpp"
using namespace cv::dnn::cuda4dnn;
#endif

namespace cv
@ -195,7 +200,7 @@ public:
_locPredTransposed = getParameter<bool>(params, "loc_pred_transposed", 0, false, false);
_bboxesNormalized = getParameter<bool>(params, "normalized_bbox", 0, false, true);
_clip = getParameter<bool>(params, "clip", 0, false, false);
_groupByClasses = getParameter<bool>(params, "group_by_classes", 0, false, true);
_groupByClasses = getParameter<bool>(params, "group_by_classes", 0, false, false);

getCodeType(params);
@ -209,6 +214,7 @@ public:
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
return backendId == DNN_BACKEND_OPENCV ||
(backendId == DNN_BACKEND_CUDA && !_groupByClasses) ||
((backendId == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 || backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH) && !_locPredTransposed && _bboxesNormalized);
}
@ -929,6 +935,56 @@ public:
}
}

#ifdef HAVE_CUDA
Ptr<BackendNode> initCUDA(
void *context_,
const std::vector<Ptr<BackendWrapper>>& inputs,
const std::vector<Ptr<BackendWrapper>>& outputs
) override
{
auto context = reinterpret_cast<csl::CSLContext*>(context_);

auto locations_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
auto locations_shape = locations_wrapper->getShape();

auto priors_wrapper = inputs[2].dynamicCast<CUDABackendWrapper>();
auto priors_shape = priors_wrapper->getShape();

cuda4dnn::DetectionOutputConfiguration config;
config.batch_size = locations_shape[0];

if (_codeType == "CORNER")
{
config.code_type = cuda4dnn::DetectionOutputConfiguration::CodeType::CORNER;
}
else if(_codeType == "CENTER_SIZE")
{
config.code_type = cuda4dnn::DetectionOutputConfiguration::CodeType::CENTER_SIZE;
}
else
{
CV_Error(Error::StsNotImplemented, _codeType + " code type not supported by CUDA backend in DetectionOutput layer");
}

config.share_location = _shareLocation;
config.num_priors = priors_shape[2] / 4;
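// the priors input is laid out as [1, 2, num_priors * 4]; e.g. 34928 / 4 = 8732 priors for a 300x300 SSD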
config.num_classes = _numClasses;
config.background_class_id = _backgroundLabelId;

config.transpose_location = _locPredTransposed;
config.variance_encoded_in_target = _varianceEncodedInTarget;
config.normalized_bbox = _bboxesNormalized;
config.clip_box = _clip;

config.classwise_topK = _topK;
config.confidence_threshold = _confidenceThreshold;
config.nms_threshold = _nmsThreshold;

config.keepTopK = _keepTopK;
return make_cuda_node<cuda4dnn::DetectionOutputOp>(preferableTarget, std::move(context->stream), config);
}
#endif

#ifdef HAVE_DNN_IE_NN_BUILDER_2019
virtual Ptr<BackendNode> initInfEngine(const std::vector<Ptr<BackendWrapper> >&) CV_OVERRIDE
{
@ -262,6 +262,7 @@ TEST_P(Test_Model, DetectionMobilenetSSD)
else if (target == DNN_TARGET_CUDA_FP16)
{
scoreDiff = 4e-4;
iouDiff = 1e-2;
}
float confThreshold = FLT_MIN;
double nmsThreshold = 0.0;
@ -352,6 +353,11 @@ TEST_P(Test_Model, Detection_normalized)
double scoreDiff = 1e-5, iouDiff = 1e-5;
float confThreshold = FLT_MIN;
double nmsThreshold = 0.0;
if (target == DNN_TARGET_CUDA)
{
scoreDiff = 3e-4;
iouDiff = 0.018;
}
if (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD || target == DNN_TARGET_CUDA_FP16)
{
scoreDiff = 5e-3;
@ -691,6 +691,13 @@ TEST_P(Test_TensorFlow_nets, Faster_RCNN)
checkBackend();

double scoresDiff = backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 ? 2.9e-5 : 1e-5;
double iouDiff = 1e-4;
if (target == DNN_TARGET_CUDA)
{
// for faster_rcnn_resnet50_coco_2018_01_28
scoresDiff = 0.06;
iouDiff = 0.08;
}
for (int i = 0; i < 2; ++i)
{
std::string proto = findDataFile("dnn/" + names[i] + ".pbtxt");
@ -706,7 +713,7 @@ TEST_P(Test_TensorFlow_nets, Faster_RCNN)
Mat out = net.forward();

Mat ref = blobFromNPY(findDataFile("dnn/tensorflow/" + names[i] + ".detection_out.npy"));
normAssertDetections(ref, out, names[i].c_str(), 0.3, scoresDiff);
normAssertDetections(ref, out, names[i].c_str(), 0.3, scoresDiff, iouDiff);
}
}
@ -1244,7 +1251,7 @@ TEST_P(Test_TensorFlow_nets, EfficientDet)
if (target == DNN_TARGET_CUDA_FP16)
{
scoreDiff = 0.002;
iouDiff = 0.003;
iouDiff = 0.004;
}
normAssertDetections(ref, out, "", 0.5, scoreDiff, iouDiff);
expectNoFallbacksFromIE(net);