mirror of
https://github.com/opencv/opencv.git
synced 2024-11-27 04:36:36 +08:00
Merge pull request #17200 from YashasSamaga:cuda4dnn-general-opt1
cuda4dnn: optimizations for swish, mish, sigmoid, region, resize based ops, transpose, identity-conv fusion * bunch of optimizations * more accurate implementation for mish
This commit is contained in:
parent
666be238d8
commit
d981d04c76
@ -9,6 +9,7 @@
|
||||
#include "types.hpp"
|
||||
#include "grid_stride_range.hpp"
|
||||
#include "execution.hpp"
|
||||
#include "memory.hpp"
|
||||
|
||||
#include "../cuda4dnn/csl/stream.hpp"
|
||||
#include "../cuda4dnn/csl/tensor.hpp"
|
||||
@ -102,10 +103,10 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
||||
|
||||
#pragma unroll 1 /* disable unrolling */
|
||||
for (int i = 0; i < CHANNELS_PER_ITER; i++) {
|
||||
auto v_00 = input[in_offset_r0 + in_x0],
|
||||
v_01 = input[in_offset_r0 + in_x1],
|
||||
v_10 = input[in_offset_r1 + in_x0],
|
||||
v_11 = input[in_offset_r1 + in_x1];
|
||||
auto v_00 = load_ldg(input[in_offset_r0 + in_x0]),
|
||||
v_01 = load_ldg(input[in_offset_r0 + in_x1]),
|
||||
v_10 = load_ldg(input[in_offset_r1 + in_x0]),
|
||||
v_11 = load_ldg(input[in_offset_r1 + in_x1]);
|
||||
|
||||
output[out_idx] =
|
||||
v_00 +
|
||||
|
@ -30,8 +30,10 @@ struct tanh_functor {
|
||||
template <class T>
|
||||
struct swish_functor {
|
||||
__device__ T operator()(T value) {
|
||||
using csl::device::sigmoid;
|
||||
return value * sigmoid(value);
|
||||
// f(x) = x * sigmoid(x)
|
||||
using csl::device::fast_divide;
|
||||
using csl::device::fast_exp;
|
||||
return fast_divide(value, static_cast<T>(1) + fast_exp(-value));
|
||||
}
|
||||
};
|
||||
|
||||
@ -44,11 +46,30 @@ struct mish_functor {
|
||||
}
|
||||
};
|
||||
|
||||
// Single-precision specialization of the Mish activation, written in terms of the
// fast_exp / fast_divide helpers instead of calling tanh and log1pexp directly.
//
// With e = exp(x) and n = e*e + 2*e, the identity
//     tanh(log(1 + e)) = ((1 + e)^2 - 1) / ((1 + e)^2 + 1) = n / (n + 2)
// gives mish(x) = x * n / (n + 2) = x - 2 * x / (n + 2); the branches below pick
// between these algebraically equivalent forms depending on the input range.
template <>
|
||||
struct mish_functor<float> {
|
||||
// Returns the Mish activation of `value`.
__device__ float operator()(float value) {
|
||||
// f(x) = x * tanh(log1pexp(x));
|
||||
using csl::device::fast_divide;
|
||||
using csl::device::fast_exp;
|
||||
|
||||
auto e = fast_exp(value);
|
||||
// For very negative inputs n / (n + 2) is approximately e, so the division is skipped.
// NOTE(review): the -18 cutoff is presumably tuned for float precision -- confirm.
if (value <= -18.0f)
|
||||
return value * e;
|
||||
|
||||
auto n = e * e + 2 * e;
|
||||
// Moderately negative inputs use the direct product form x * n / (n + 2).
// NOTE(review): the -5 cutoff presumably selects the more accurate form per range -- confirm.
if (value <= -5.0f)
|
||||
return value * fast_divide(n, n + 2);
|
||||
|
||||
// Remaining inputs use the equivalent form x - 2 * x / (n + 2).
return value - 2 * fast_divide(value, n + 2);
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct sigmoid_functor {
|
||||
__device__ T operator()(T value) {
|
||||
using csl::device::sigmoid;
|
||||
return sigmoid(value);
|
||||
using csl::device::fast_sigmoid;
|
||||
return fast_sigmoid(value);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -160,6 +160,15 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace de
|
||||
template <> inline __device__ __half2 ceil(__half2 value) { return h2ceil(value); }
|
||||
#endif
|
||||
|
||||
// Division helper: the generic version is a plain `x / y`; the float specialization
// uses the __fdividef intrinsic, which trades accuracy for speed.
template <class T> __device__ T fast_divide(T x, T y) { return x / y; }
|
||||
template <> inline __device__ float fast_divide(float x, float y) { return __fdividef(x, y); }
|
||||
|
||||
// Exponential helper: the generic version forwards to exp(); the float specialization
// uses the __expf intrinsic, which trades accuracy for speed.
template <class T> __device__ T fast_exp(T value) { return exp(value); }
|
||||
template <> inline __device__ float fast_exp(float value) { return __expf(value); }
|
||||
|
||||
// Sigmoid helper: the generic version forwards to sigmoid(); the float specialization
// evaluates 1 / (1 + exp(-value)) directly with the __fdividef and __expf intrinsics.
template <class T> __device__ T fast_sigmoid(T value) { return sigmoid(value); }
|
||||
template <> inline __device__ float fast_sigmoid(float value) { return __fdividef(1, 1 + __expf(-value)); }
|
||||
|
||||
}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */
|
||||
|
||||
#endif /* OPENCV_DNN_SRC_CUDA_MATH_HPP */
|
||||
|
32
modules/dnn/src/cuda/memory.hpp
Normal file
32
modules/dnn/src/cuda/memory.hpp
Normal file
@ -0,0 +1,32 @@
|
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#ifndef OPENCV_DNN_SRC_CUDA_MEMORY_HPP
|
||||
#define OPENCV_DNN_SRC_CUDA_MEMORY_HPP
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device {
|
||||
|
||||
// Loads the value referred to by `src` and returns it by value.
// When compiling device code for __CUDA_ARCH__ >= 350 the load is issued through
// __ldg (read-only data cache load); otherwise it is an ordinary read of `src`.
// NOTE(review): __ldg presumably requires the data to stay unmodified while the
// kernel runs, and is only provided for certain types -- confirm for every T used.
template <class T>
|
||||
__device__ T load_ldg(const T& src) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)
|
||||
return __ldg(&src);
|
||||
#else
|
||||
return src;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Pointer overload: loads the value `src` points to and returns it by value.
// When compiling device code for __CUDA_ARCH__ >= 350 the load is issued through
// __ldg (read-only data cache load); otherwise it is a plain dereference.
// `src` is dereferenced without any null check.
template <class T>
|
||||
__device__ T load_ldg(const T* src) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)
|
||||
return __ldg(src);
|
||||
#else
|
||||
return *src;
|
||||
#endif
|
||||
}
|
||||
|
||||
}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */
|
||||
|
||||
#endif /* OPENCV_DNN_SRC_CUDA_MEMORY_HPP */
|
@ -7,7 +7,6 @@
|
||||
|
||||
#include "array.hpp"
|
||||
#include "types.hpp"
|
||||
#include "vector_traits.hpp"
|
||||
#include "grid_stride_range.hpp"
|
||||
#include "execution.hpp"
|
||||
#include "kernel_dispatcher.hpp"
|
||||
@ -50,82 +49,60 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
||||
}
|
||||
}
|
||||
|
||||
template <class T, int TILE_SIZE, std::size_t N>
|
||||
template <class T, int TILE_SIZE, int ROWS_PER_THREAD>
|
||||
__global__ void transpose(Span<T> output, View<T> input, size_type in_width, size_type out_width)
|
||||
{
|
||||
using vector_type = get_vector_type_t<T, N>;
|
||||
|
||||
__shared__ T tile[TILE_SIZE][TILE_SIZE + 1];
|
||||
|
||||
/* blockDim.y = TILE_SIZE, blockDim.x = TILE_SIZE/N */
|
||||
const index_type in_x = blockIdx.x * TILE_SIZE + threadIdx.x * N;
|
||||
const index_type in_y = blockIdx.y * TILE_SIZE + threadIdx.y;
|
||||
/* blockDim.y = TILE_SIZE / ROWS_PER_THREAD, blockDim.x = TILE_SIZE */
|
||||
const index_type in_x = blockIdx.x * TILE_SIZE + threadIdx.x;
|
||||
const index_type in_y_begin = blockIdx.y * TILE_SIZE + threadIdx.y;
|
||||
|
||||
/* Every valid input location has a corresponding output location and vice versa.
|
||||
* Hence, if we do not load values into the shared memory for a given location, we
|
||||
* also won't read them for storing in the output.
|
||||
*/
|
||||
if (in_x < in_width && in_y < out_width)
|
||||
for (int j = 0; j < TILE_SIZE; j += TILE_SIZE / ROWS_PER_THREAD)
|
||||
{
|
||||
vector_type vec;
|
||||
auto input_vPtr = vector_type::get_pointer(input.data());
|
||||
v_load(vec, input_vPtr[(in_y * in_width + in_x) / N]);
|
||||
|
||||
for (int i = 0; i < vector_type::size(); i++)
|
||||
tile[threadIdx.y][threadIdx.x * N + i] = vec.data[i];
|
||||
const auto in_y_current = in_y_begin + j;
|
||||
if (in_x < in_width && in_y_current < out_width)
|
||||
tile[threadIdx.y + j][threadIdx.x] = input[in_y_current * in_width + in_x];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
/* Note that `blockDim.x * N` is equal to `blockDim.y`. Since there are an equal
|
||||
* number of them, we can interchange `threadIdx.x` and `threadIdx.y` without changing
|
||||
* result. The advantage of interchanging is that consecutive output indices map to
|
||||
/* We interchange `threadIdx.x` and `threadIdx.y` so that consecutive output indices map to
|
||||
* consecutive threads. This would allow writes across threads in a warp to be coalesced.
|
||||
*/
|
||||
const index_type out_x = blockIdx.y * TILE_SIZE + threadIdx.x * N;
|
||||
const index_type out_y = blockIdx.x * TILE_SIZE + threadIdx.y;
|
||||
const index_type out_x = blockIdx.y * TILE_SIZE + threadIdx.x;
|
||||
const index_type out_y_begin = blockIdx.x * TILE_SIZE + threadIdx.y;
|
||||
|
||||
if (out_x < out_width && out_y < in_width)
|
||||
for (int j = 0; j < TILE_SIZE; j += TILE_SIZE / ROWS_PER_THREAD)
|
||||
{
|
||||
vector_type vec;
|
||||
for (int i = 0; i < vector_type::size(); i++)
|
||||
vec.data[i] = tile[threadIdx.x * N + i][threadIdx.y];
|
||||
|
||||
auto output_vPtr = vector_type::get_pointer(output.data());
|
||||
v_store(output_vPtr[(out_y * out_width + out_x) / N], vec);
|
||||
const auto out_y_current = out_y_begin + j;
|
||||
if (out_x < out_width && out_y_current < in_width)
|
||||
output[out_y_current * out_width + out_x] = tile[threadIdx.x][threadIdx.y + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class T, std::size_t N> static
|
||||
void launch_transpose_kernel(const Stream& stream, Span<T> output, View<T> input, size_type in_width, size_type out_width)
|
||||
{
|
||||
CV_Assert(is_fully_aligned<T>(output, N));
|
||||
CV_Assert(is_fully_aligned<T>(input, N));
|
||||
CV_Assert(in_width % N == 0);
|
||||
CV_Assert(out_width % N == 0);
|
||||
|
||||
constexpr int TILE_SIZE = 32;
|
||||
constexpr int TILE_SIZE_X = TILE_SIZE/N, TILE_SIZE_Y = TILE_SIZE;
|
||||
auto kernel = raw::transpose<T, TILE_SIZE, N>;
|
||||
|
||||
dim3 grid_size((in_width/N + TILE_SIZE_X - 1)/TILE_SIZE_X, (out_width + TILE_SIZE_Y - 1)/TILE_SIZE_Y);
|
||||
dim3 block_size(TILE_SIZE_X, TILE_SIZE_Y);
|
||||
auto policy = execution_policy(grid_size, block_size, stream);
|
||||
|
||||
launch_kernel(kernel, policy, output, input, in_width, out_width);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void transpose(const Stream& stream, Span<T> output, View<T> input, std::size_t in_width, std::size_t out_width)
|
||||
{
|
||||
if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4) && in_width % 4 == 0 && out_width % 4 == 0) {
|
||||
launch_transpose_kernel<T, 4>(stream, output, input, in_width, out_width);
|
||||
} else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2) && in_width % 2 == 0 && out_width % 2 == 0) {
|
||||
launch_transpose_kernel<T, 2>(stream, output, input, in_width, out_width);
|
||||
} else {
|
||||
launch_transpose_kernel<T, 1>(stream, output, input, in_width, out_width);
|
||||
}
|
||||
/* Each block processes a TILE_SIZE x TILE_SIZE piece */
|
||||
constexpr int TILE_SIZE = 32;
|
||||
|
||||
/* Each thread processes ROWS_PER_THREAD rows. We do this to decrease the number of threads required
|
||||
* in a block so that the cost of the block-wide synchronization is minimized.
|
||||
*/
|
||||
constexpr int ROWS_PER_THREAD = 4;
|
||||
|
||||
dim3 grid_size((in_width + TILE_SIZE - 1) / TILE_SIZE, (out_width + TILE_SIZE - 1) / TILE_SIZE);
|
||||
dim3 block_size(TILE_SIZE, TILE_SIZE / ROWS_PER_THREAD);
|
||||
auto policy = execution_policy(grid_size, block_size, stream);
|
||||
|
||||
auto kernel = raw::transpose<T, TILE_SIZE, ROWS_PER_THREAD>;
|
||||
launch_kernel(kernel, policy, output, input, in_width, out_width);
|
||||
}
|
||||
|
||||
template void transpose(const Stream&, Span<__half>, View<__half>, std::size_t, std::size_t);
|
||||
|
@ -47,20 +47,20 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
||||
const auto y = (box_index % batch_inner_size) / row_inner_size;
|
||||
const auto x = (box_index % row_inner_size) / col_inner_size;
|
||||
|
||||
using device::sigmoid;
|
||||
output[box_offset + 0] = (T(x) + sigmoid(input[box_offset + 0])) / T(cols);
|
||||
output[box_offset + 1] = (T(y) + sigmoid(input[box_offset + 1])) / T(rows);
|
||||
using device::fast_sigmoid;
|
||||
output[box_offset + 0] = (T(x) + fast_sigmoid(input[box_offset + 0])) / T(cols);
|
||||
output[box_offset + 1] = (T(y) + fast_sigmoid(input[box_offset + 1])) / T(rows);
|
||||
|
||||
vector2_type bias_xy;
|
||||
v_load(bias_xy, bias_vPtr[box_of_the_cell]);
|
||||
|
||||
using device::exp;
|
||||
output[box_offset + 2] = exp(input[box_offset + 2]) * bias_xy.data[0] / T(width_norm);
|
||||
output[box_offset + 3] = exp(input[box_offset + 3]) * bias_xy.data[1] / T(height_norm);
|
||||
using device::fast_exp;
|
||||
output[box_offset + 2] = fast_exp(input[box_offset + 2]) * bias_xy.data[0] / T(width_norm);
|
||||
output[box_offset + 3] = fast_exp(input[box_offset + 3]) * bias_xy.data[1] / T(height_norm);
|
||||
|
||||
/* squash objectness score into a probability */
|
||||
using device::sigmoid;
|
||||
T objectness_prob = sigmoid(input[box_offset + 4]);
|
||||
using device::fast_sigmoid;
|
||||
T objectness_prob = fast_sigmoid(input[box_offset + 4]);
|
||||
|
||||
/* ignore prediction if the objectness probability is less than the cutoff */
|
||||
if (objectness_prob < object_prob_cutoff)
|
||||
@ -91,7 +91,8 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
||||
* to obtain the actual class probability, we multiply the conditional probability
|
||||
* with the object probability
|
||||
*/
|
||||
auto actual_class_prob = objectness_prob * sigmoid(input[idx]);
|
||||
using device::fast_sigmoid;
|
||||
auto actual_class_prob = objectness_prob * fast_sigmoid(input[idx]);
|
||||
if (actual_class_prob <= class_prob_cutoff)
|
||||
actual_class_prob = T(0);
|
||||
output[idx] = actual_class_prob;
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "types.hpp"
|
||||
#include "grid_stride_range.hpp"
|
||||
#include "execution.hpp"
|
||||
#include "memory.hpp"
|
||||
|
||||
#include "../cuda4dnn/csl/stream.hpp"
|
||||
#include "../cuda4dnn/csl/tensor.hpp"
|
||||
@ -70,7 +71,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
||||
index_type out_idx = c_start * out_image_size + y * out_width + x;
|
||||
|
||||
for (int i = 0; i < CHANNELS_PER_ITER; i++) {
|
||||
output[out_idx] = input[in_idx];
|
||||
output[out_idx] = load_ldg(input[in_idx]);
|
||||
|
||||
in_idx += in_image_size;
|
||||
out_idx += out_image_size;
|
||||
@ -134,10 +135,10 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
||||
|
||||
#pragma unroll 1 /* disable unrolling to reduce register pressure; not sure how but it works */
|
||||
for (auto c = c_start; c < c_end; c++) {
|
||||
auto v_00 = input[in_offset_r0 + in_x0],
|
||||
v_01 = input[in_offset_r0 + in_x1],
|
||||
v_10 = input[in_offset_r1 + in_x0],
|
||||
v_11 = input[in_offset_r1 + in_x1];
|
||||
auto v_00 = load_ldg(input[in_offset_r0 + in_x0]),
|
||||
v_01 = load_ldg(input[in_offset_r0 + in_x1]),
|
||||
v_10 = load_ldg(input[in_offset_r1 + in_x0]),
|
||||
v_11 = load_ldg(input[in_offset_r1 + in_x1]);
|
||||
|
||||
output[out_idx] =
|
||||
v_00 +
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include "types.hpp"
|
||||
#include "grid_stride_range.hpp"
|
||||
#include "execution.hpp"
|
||||
#include "memory.hpp"
|
||||
|
||||
#include "../cuda4dnn/csl/stream.hpp"
|
||||
#include "../cuda4dnn/csl/tensor.hpp"
|
||||
@ -118,7 +119,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
||||
const auto in_idx = in_offset + iy * in_width;
|
||||
for (auto ix = x_start; ix < x_end; ix++)
|
||||
{
|
||||
max_val = max(max_val, input[in_idx + ix]);
|
||||
max_val = max(max_val, load_ldg(input[in_idx + ix]));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "types.hpp"
|
||||
#include "memory.hpp"
|
||||
|
||||
#include "../cuda4dnn/csl/pointer.hpp"
|
||||
|
||||
@ -86,6 +87,16 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace de
|
||||
dest.raw = src->raw;
|
||||
}
|
||||
|
||||
// Copies the `raw` member of `src` into `dest.raw`, reading it through load_ldg.
// NOTE(review): V is presumably one of the vector wrapper types of this header
// whose `raw` member holds the packed data -- confirm.
template <class V>
|
||||
__device__ void v_load_ldg(V& dest, const V& src) {
|
||||
dest.raw = load_ldg(src.raw);
|
||||
}
|
||||
|
||||
// Pointer overload: copies `src->raw` into `dest.raw`, reading it through load_ldg.
// `src` is dereferenced without any null check.
template <class V>
|
||||
__device__ void v_load_ldg(V& dest, const V* src) {
|
||||
dest.raw = load_ldg(src->raw);
|
||||
}
|
||||
|
||||
template <class V>
|
||||
__device__ void v_store(V* dest, const V& src) {
|
||||
dest->raw = src.raw;
|
||||
|
@ -167,6 +167,10 @@ public:
|
||||
|
||||
virtual bool tryFuse(Ptr<Layer>& top) CV_OVERRIDE
|
||||
{
|
||||
Ptr<BlankLayer> blank_layer = top.dynamicCast<BlankLayer>();
|
||||
if (blank_layer)
|
||||
return true;
|
||||
|
||||
Mat w, b;
|
||||
top->getScaleShift(w, b);
|
||||
if (!w.empty() || !b.empty())
|
||||
|
Loading…
Reference in New Issue
Block a user