opencv/modules/dnn/src/cuda/prior_box.cu

175 lines
7.6 KiB
Plaintext
Raw Normal View History

Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low CUDA backend for the DNN module * stub cuda4dnn design * minor fixes for tests and doxygen * add csl public api directory to module headers * add low-level CSL components * add high-level CSL components * integrate csl::Tensor into backbone code * switch to CPU iff unsupported; otherwise, fail on error * add fully connected layer * add softmax layer * add activation layers * support arbitary rank TensorDescriptor * pass input wrappers to `initCUDA()` * add 1d/2d/3d-convolution * add pooling layer * reorganize and refactor code * fixes for gcc, clang and doxygen; remove cxx14/17 code * add blank_layer * add LRN layer * add rounding modes for pooling layer * split tensor.hpp into tensor.hpp and tensor_ops.hpp * add concat layer * add scale layer * add batch normalization layer * split math.cu into activations.cu and math.hpp * add eltwise layer * add flatten layer * add tensor transform api * add asymmetric padding support for convolution layer * add reshape layer * fix rebase issues * add permute layer * add padding support for concat layer * refactor and reorganize code * add normalize layer * optimize bias addition in scale layer * add prior box layer * fix and optimize normalize layer * add asymmetric padding support for pooling layer * add event API * improve pooling performance for some padding scenarios * avoid over-allocation of compute resources to kernels * improve prior box performance * enable layer fusion * add const layer * add resize layer * add slice layer * add padding layer * add deconvolution layer * fix channelwise ReLU initialization * add vector traits * add vectorized versions of relu, clipped_relu, power * add vectorized concat kernels * improve concat_with_offsets performance * vectorize scale and bias kernels * add support for multi-billion element tensors * vectorize prior box kernels * fix address alignment check * improve bias addition performance of conv/deconv/fc layers * 
restructure code for supporting multiple targets * add DNN_TARGET_CUDA_FP64 * add DNN_TARGET_FP16 * improve vectorization * add region layer * improve tensor API, add dynamic ranks 1. use ManagedPtr instead of a Tensor in backend wrapper 2. add new methods to tensor classes - size_range: computes the combined size of for a given axis range - tensor span/view can be constructed from a raw pointer and shape 3. the tensor classes can change their rank at runtime (previously rank was fixed at compile-time) 4. remove device code from tensor classes (as they are unused) 5. enforce strict conditions on tensor class APIs to improve debugging ability * fix parametric relu activation * add squeeze/unsqueeze tensor API * add reorg layer * optimize permute and enable 2d permute * enable 1d and 2d slice * add split layer * add shuffle channel layer * allow tensors of different ranks in reshape primitive * patch SliceOp to allow Crop Layer * allow extra shape inputs in reshape layer * use `std::move_backward` instead of `std::move` for insert in resizable_static_array * improve workspace management * add spatial LRN * add nms (cpu) to region layer * add max pooling with argmax ( and a fix to limits.hpp) * add max unpooling layer * rename DNN_TARGET_CUDA_FP32 to DNN_TARGET_CUDA * update supportBackend to be more rigorous * remove stray include from preventing non-cuda build * include op_cuda.hpp outside condition #if * refactoring, fixes and many optimizations * drop DNN_TARGET_CUDA_FP64 * fix gcc errors * increase max. tensor rank limit to six * add Interp layer * drop custom layers; use BackendNode * vectorize activation kernels * fixes for gcc * remove wrong assertion * fix broken assertion in unpooling primitive * fix build errors in non-CUDA build * completely remove workspace from public API * fix permute layer * enable accuracy and perf. 
tests for DNN_TARGET_CUDA * add asynchronous forward * vectorize eltwise ops * vectorize fill kernel * fixes for gcc * remove CSL headers from public API * remove csl header source group from cmake * update min. cudnn version in cmake * add numerically stable FP32 log1pexp * refactor code * add FP16 specialization to cudnn based tensor addition * vectorize scale1 and bias1 + minor refactoring * fix doxygen build * fix invalid alignment assertion * clear backend wrappers before allocateLayers * ignore memory lock failures * do not allocate internal blobs * integrate NVTX * add numerically stable half precision log1pexp * fix indentation, following coding style, improve docs * remove accidental modification of IE code * Revert "add asynchronous forward" This reverts commit 1154b9da9da07e9b52f8a81bdcea48cf31c56f70. * [cmake] throw error for unsupported CC versions * fix rebase issues * add more docs, refactor code, fix bugs * minor refactoring and fixes * resolve warnings/errors from clang * remove haveCUDA() checks from supportBackend() * remove NVTX integration * changes based on review comments * avoid exception when no CUDA device is present * add color code for CUDA in Net::dump
2019-10-21 19:28:00 +08:00
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "array.hpp"
#include "math.hpp"
#include "types.hpp"
#include "vector_traits.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"
#include <cstddef>
using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
/* Writes the prior boxes for every point of a layerWidth x layerHeight feature map.
 *
 * For the point (x, y) and each (box i, offset j) pair, the box centre is
 *     ((x + offsetX[j]) * stepX, (y + offsetY[j]) * stepY)
 * and the box extends boxWidth[i] / 2 and boxHeight[i] / 2 on either side of it.
 * Every box is stored as one 4-element vector (xmin, ymin, xmax, ymax). The boxes of
 * a point are stored contiguously: all offsets of box 0, then all offsets of box 1, ...
 *
 * Normalize == true : all four coordinates are divided by imageWidth / imageHeight
 * Normalize == false: coordinates are left as they are and 1 is subtracted from xmax and ymax
 *
 * The indexing below reads boxHeight[i] for every i < boxWidth.size() and offsetY[j] for
 * every j < offsetX.size(), and writes
 * layerWidth * layerHeight * boxWidth.size() * offsetX.size() vectors starting at output.data().
 */
template <class T, bool Normalize>
__global__ void prior_box(
    Span<T> output,
    View<float> boxWidth, View<float> boxHeight, View<float> offsetX, View<float> offsetY, float stepX, float stepY,
    size_type layerWidth, size_type layerHeight,
    size_type imageWidth, size_type imageHeight)
{
    /* each box is two pairs of coordinates, i.e. four values; since the region written here
     * consists entirely of such boxes, the output is treated as guaranteed to be aligned
     * to a boundary of four values and is accessed through 4-element vectors
     */
    using vector_type = get_vector_type_t<T, 4>;
    auto boxes_out = vector_type::get_pointer(output.data());

    /* one iteration of the outer loop handles one point of the feature map */
    const size_type total_points = layerWidth * layerHeight;
    for (auto point : grid_stride_range(total_points)) {
        const index_type col = point % layerWidth;
        const index_type row = point / layerWidth;

        /* index (in vectors) of the first box that belongs to this point */
        index_type slot = point * offsetX.size() * boxWidth.size();

        for (int box = 0; box < boxWidth.size(); box++) {
            for (int off = 0; off < offsetX.size(); off++) {
                const float cx = (col + offsetX[off]) * stepX;
                const float cy = (row + offsetY[off]) * stepY;

                float x_min = cx - boxWidth[box] * 0.5f;
                float y_min = cy - boxHeight[box] * 0.5f;
                float x_max = cx + boxWidth[box] * 0.5f;
                float y_max = cy + boxHeight[box] * 0.5f;

                if (Normalize) {
                    x_min = x_min / imageWidth;
                    y_min = y_min / imageHeight;
                    x_max = x_max / imageWidth;
                    y_max = y_max / imageHeight;
                } else {
                    x_max = x_max - 1.0f;
                    y_max = y_max - 1.0f;
                }

                vector_type corners;
                corners.data[0] = x_min;
                corners.data[1] = y_min;
                corners.data[2] = x_max;
                corners.data[3] = y_max;

                v_store(boxes_out[slot], corners);
                slot++;
            }
        }
    }
}
/* Clamps every element of `output` to the closed interval [0, 1] in place.
 *
 * The bounds are written as float literals and converted to T once, before the loop,
 * so that no double-precision constant is introduced into a kernel that is
 * instantiated for float and __half.
 */
template <class T>
__global__ void prior_box_clip(Span<T> output) {
    using device::clamp;

    const T lower = static_cast<T>(0.0f);
    const T upper = static_cast<T>(1.0f);

    for (auto i : grid_stride_range(output.size()))
        output[i] = clamp<T>(output[i], lower, upper);
}
/* Fills `output` with a single variance value using 4-element vector stores.
 *
 * output.size() / 4 vectors are written starting at output.data(); if the size is not
 * a multiple of four, the trailing elements are not touched.
 */
template <class T>
__global__ void prior_box_set_variance1(Span<T> output, float variance) {
    using vector_type = get_vector_type_t<T, 4>;
    auto dst = vector_type::get_pointer(output.data());

    /* every store writes the same four values, so the vector is built only once */
    vector_type packet;
    for (int lane = 0; lane < 4; lane++)
        packet.data[lane] = variance;

    const auto num_packets = output.size() / 4;
    for (auto k : grid_stride_range(num_packets))
        v_store(dst[k], packet);
}
/* Fills `output` with a repeating group of four variance values using 4-element vector stores.
 *
 * output.size() / 4 vectors are written starting at output.data(); element j of every
 * vector is variance[j]. If the size is not a multiple of four, the trailing elements
 * are not touched.
 */
template <class T>
__global__ void prior_box_set_variance4(Span<T> output, array<float, 4> variance) {
    using vector_type = get_vector_type_t<T, 4>;
    auto dst = vector_type::get_pointer(output.data());

    /* every store writes the same four values, so the vector is built only once */
    vector_type packet;
    for (int lane = 0; lane < 4; lane++)
        packet.data[lane] = variance[lane];

    const auto num_packets = output.size() / 4;
    for (auto k : grid_stride_range(num_packets))
        v_store(dst[k], packet);
}
}
/* Launches raw::prior_box<T, Normalize> on `stream`.
 *
 * The execution policy is sized for layerWidth * layerHeight work items (one per
 * feature map point) and requests no dynamic shared memory. All remaining arguments
 * are forwarded to the kernel unchanged.
 */
template <class T, bool Normalize>
static void launch_prior_box_kernel(
    const Stream& stream,
    Span<T> output, View<float> boxWidth, View<float> boxHeight, View<float> offsetX, View<float> offsetY, float stepX, float stepY,
    std::size_t layerWidth, std::size_t layerHeight, std::size_t imageWidth, std::size_t imageHeight)
{
    auto box_kernel = raw::prior_box<T, Normalize>;

    const auto point_count = layerWidth * layerHeight;
    auto launch_policy = make_policy(box_kernel, point_count, 0, stream);

    launch_kernel(
        box_kernel, launch_policy,
        output,
        boxWidth, boxHeight,
        offsetX, offsetY,
        stepX, stepY,
        layerWidth, layerHeight,
        imageWidth, imageHeight);
}
/* Generates prior boxes and their variances into `output`; all work is enqueued on `stream`.
 *
 * `output` is treated as two consecutive channels of
 *     channel_size = layerHeight * layerWidth * numPriors * 4
 * elements each:
 *   - the first channel receives the box coordinates (clamped to [0, 1] when `clip` is set)
 *   - the second channel receives the variances
 *
 * boxWidth/boxHeight, offsetX/offsetY, stepX/stepY and the layer/image dimensions are
 * forwarded to the box generation kernel. `normalize` selects the kernel variant that
 * divides the coordinates by the image dimensions.
 *
 * `variance` must contain either one value (written to every element of the second
 * channel) or four values (repeated for every box).
 *
 * Fails with CV_Assert, before any kernel is launched, if `output` does not contain
 * exactly 2 * channel_size elements or if `variance` has neither one nor four values.
 */
template <class T>
void generate_prior_boxes(
    const Stream& stream,
    Span<T> output,
    View<float> boxWidth, View<float> boxHeight, View<float> offsetX, View<float> offsetY, float stepX, float stepY,
    std::vector<float> variance,
    std::size_t numPriors,
    std::size_t layerWidth, std::size_t layerHeight,
    std::size_t imageWidth, std::size_t imageHeight,
    bool normalize, bool clip)
{
    /* validate everything up front: the kernels are enqueued asynchronously, hence a check
     * made after a launch cannot stop that kernel from writing to a wrongly sized output
     */
    const std::size_t channel_size = layerHeight * layerWidth * numPriors * 4;
    CV_Assert(channel_size * 2 == output.size());

    /* the four-value path copies `variance` into a fixed array of four floats */
    CV_Assert(variance.size() == 1 || variance.size() == 4);

    if (normalize) {
        launch_prior_box_kernel<T, true>(
            stream, output, boxWidth, boxHeight, offsetX, offsetY, stepX, stepY,
            layerWidth, layerHeight, imageWidth, imageHeight
        );
    } else {
        launch_prior_box_kernel<T, false>(
            stream, output, boxWidth, boxHeight, offsetX, offsetY, stepX, stepY,
            layerWidth, layerHeight, imageWidth, imageHeight
        );
    }

    if (clip) {
        /* only the box coordinates (first channel) are clamped */
        auto output_span_c1 = Span<T>(output.data(), channel_size);
        auto kernel = raw::prior_box_clip<T>;
        auto policy = make_policy(kernel, output_span_c1.size(), 0, stream);
        launch_kernel(kernel, policy, output_span_c1);
    }

    /* the variance kernels store four values at a time, hence size / 4 work items */
    auto output_span_c2 = Span<T>(output.data() + channel_size, channel_size);
    if (variance.size() == 1) {
        auto kernel = raw::prior_box_set_variance1<T>;
        auto policy = make_policy(kernel, output_span_c2.size() / 4, 0, stream);
        launch_kernel(kernel, policy, output_span_c2, variance[0]);
    } else {
        array<float, 4> variance_k;
        variance_k.assign(std::begin(variance), std::end(variance));
        auto kernel = raw::prior_box_set_variance4<T>;
        auto policy = make_policy(kernel, output_span_c2.size() / 4, 0, stream);
        launch_kernel(kernel, policy, output_span_c2, variance_k);
    }
}
/* explicit instantiations of generate_prior_boxes for the __half and float element types */
template void generate_prior_boxes<__half>(
    const Stream&, Span<__half>,
    View<float>, View<float>, View<float>, View<float>, float, float,
    std::vector<float>, std::size_t,
    std::size_t, std::size_t, std::size_t, std::size_t,
    bool, bool);

template void generate_prior_boxes<float>(
    const Stream&, Span<float>,
    View<float>, View<float>, View<float>, View<float>, float, float,
    std::vector<float>, std::size_t,
    std::size_t, std::size_t, std::size_t, std::size_t,
    bool, bool);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */