diff --git a/modules/dnn/src/cuda/activations.cu b/modules/dnn/src/cuda/activations.cu
index 143361c1f3..221516dddc 100644
--- a/modules/dnn/src/cuda/activations.cu
+++ b/modules/dnn/src/cuda/activations.cu
@@ -5,7 +5,7 @@
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 
-#include "math.hpp"
+#include "functors.hpp"
 #include "types.hpp"
 #include "vector_traits.hpp"
 #include "grid_stride_range.hpp"
@@ -25,519 +25,178 @@ using namespace cv::dnn::cuda4dnn::csl::device;
 
 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
 
-    namespace raw {
-        template <class T, std::size_t N>
-        __global__ void abs_vec(Span<T> output, View<T> input) {
-            using vector_type = get_vector_type_t<T, N>;
+namespace raw {
+    template <class T, class Functor, std::size_t N, class ...FunctorArgs>
+    __global__ void generic_op_vec(Span<T> output, View<T> input, FunctorArgs ...functorArgs) {
+        using vector_type = get_vector_type_t<T, N>;
 
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
+        auto output_vPtr = vector_type::get_pointer(output.data());
+        auto input_vPtr = vector_type::get_pointer(input.data());
 
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++) {
-                    using device::abs;
-                    vec.data[j] = abs(vec.data[j]);
-                }
-                v_store(output_vPtr[i], vec);
-            }
-        }
+        Functor functor(functorArgs...);
 
-        template <class T, std::size_t N>
-        __global__ void tanh_vec(Span<T> output, View<T> input) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++) {
-                    using device::tanh;
-                    vec.data[j] = tanh(vec.data[j]);
-                }
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void swish_vec(Span<T> output, View<T> input) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++) {
-                    using device::sigmoid;
-                    vec.data[j] = vec.data[j] * sigmoid(vec.data[j]);
-                }
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void mish_vec(Span<T> output, View<T> input) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++) {
-                    using device::tanh;
-                    using device::log1pexp;
-                    vec.data[j] = vec.data[j] * tanh(log1pexp(vec.data[j]));
-                }
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void sigmoid_vec(Span<T> output, View<T> input) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++) {
-                    using device::sigmoid;
-                    vec.data[j] = sigmoid(vec.data[j]);
-                }
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void bnll_vec(Span<T> output, View<T> input) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++) {
-                    using device::log1pexp;
-                    vec.data[j] = vec.data[j] > T(0) ? vec.data[j] + log1pexp(-vec.data[j]) : log1pexp(vec.data[j]);
-                }
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void elu_vec(Span<T> output, View<T> input) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++) {
-                    using device::expm1;
-                    vec.data[j] = vec.data[j] >= T(0) ? vec.data[j] : expm1(vec.data[j]);
-                }
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void relu_vec(Span<T> output, View<T> input, T slope) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for(int j = 0; j < vector_type::size(); j++)
-                    vec.data[j] = vec.data[j] >= T(0) ? vec.data[j] : slope * vec.data[j];
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void clipped_relu_vec(Span<T> output, View<T> input, T floor, T ceiling) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                using device::clamp;
-
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++)
-                    vec.data[j] = clamp(vec.data[j], floor, ceiling);
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void axiswise_relu_vec(Span<T> output, View<T> input, size_type inner_size, View<T> slope) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            inner_size /= vector_type::size();
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                const index_type c = (i / inner_size) % static_cast<size_type>(slope.size());
-
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++)
-                    vec.data[j] = vec.data[j] > T(0) ? vec.data[j] : vec.data[j] * slope[c];
-                v_store(output_vPtr[i], vec);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void power_vec(Span<T> output, View<T> input, T exp, T scale, T shift) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto input_vPtr = vector_type::get_pointer(input.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                using device::pow;
-
-                vector_type vec;
-                v_load(vec, input_vPtr[i]);
-                for (int j = 0; j < vector_type::size(); j++)
-                    vec.data[j] = pow(shift + scale * vec.data[j], exp);
-                v_store(output_vPtr[i], vec);
-            }
+        for (auto i : grid_stride_range(output.size() / vector_type::size())) {
+            vector_type vec;
+            v_load(vec, input_vPtr[i]);
+            for (int j = 0; j < vector_type::size(); j++)
+                vec.data[j] = functor(vec.data[j]);
+            v_store(output_vPtr[i], vec);
         }
     }
 
     template <class T, std::size_t N>
-    void launch_vectorized_abs(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
+    __global__ void axiswise_relu_vec(Span<T> output, View<T> input, size_type inner_size, View<T> slope) {
+        using vector_type = get_vector_type_t<T, N>;
 
-        auto kernel = raw::abs_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input);
-    }
+        auto output_vPtr = vector_type::get_pointer(output.data());
+        auto input_vPtr = vector_type::get_pointer(input.data());
 
-    template <class T>
-    void abs(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(input.size() == output.size());
+        inner_size /= vector_type::size();
+        for (auto i : grid_stride_range(output.size() / vector_type::size())) {
+            const index_type c = (i / inner_size) % static_cast<size_type>(slope.size());
 
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_vectorized_abs<T, 4>(stream, output, input);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_vectorized_abs<T, 2>(stream, output, input);
-        } else {
-            launch_vectorized_abs<T, 1>(stream, output, input);
+            vector_type vec;
+            v_load(vec, input_vPtr[i]);
+            for (int j = 0; j < vector_type::size(); j++)
+                vec.data[j] = vec.data[j] > T(0) ? vec.data[j] : vec.data[j] * slope[c];
+            v_store(output_vPtr[i], vec);
         }
     }
+} /* namespace raw */
+
+template <class T, template <class> class Activation, std::size_t N, class ...ActivationArgs> static
+void launch_vectorized_generic_op(const Stream& stream, Span<T> output, View<T> input, ActivationArgs ...activationArgs) {
+    CV_Assert(is_fully_aligned<T>(output, N));
+    CV_Assert(is_fully_aligned<T>(input, N));
+
+    auto kernel = raw::generic_op_vec<T, Activation<T>, N, ActivationArgs...>;
+    auto policy = make_policy(kernel, output.size() / N, 0, stream);
+    launch_kernel(kernel, policy, output, input, activationArgs...);
+}
+
+template <class T, template <class> class Activation, class ...ActivationArgs> static
+void generic_op(const Stream& stream, Span<T> output, View<T> input, ActivationArgs ...activationArgs) {
+    CV_Assert(input.size() == output.size());
+
+    if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
+        launch_vectorized_generic_op<T, Activation, 4>(stream, output, input, activationArgs...);
+    } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
+        launch_vectorized_generic_op<T, Activation, 2>(stream, output, input, activationArgs...);
+    } else {
+        launch_vectorized_generic_op<T, Activation, 1>(stream, output, input, activationArgs...);
+    }
+}
+
+template <class T>
+void abs(const Stream& stream, Span<T> output, View<T> input) {
+    generic_op<T, abs_functor>(stream, output, input);
+}
+
+template <class T>
+void tanh(const Stream& stream, Span<T> output, View<T> input) {
+    generic_op<T, tanh_functor>(stream, output, input);
+}
+
+template <class T>
+void swish(const Stream& stream, Span<T> output, View<T> input) {
+    generic_op<T, swish_functor>(stream, output, input);
+}
+
+template <class T>
+void mish(const Stream& stream, Span<T> output, View<T> input) {
+    generic_op<T, mish_functor>(stream, output, input);
+}
+
+template <class T>
+void sigmoid(const Stream& stream, Span<T> output, View<T> input) {
+    generic_op<T, sigmoid_functor>(stream, output, input);
+}
+
+template <class T>
+void bnll(const Stream& stream, Span<T> output, View<T> input) {
+    generic_op<T, bnll_functor>(stream, output, input);
+}
+
+template <class T>
+void elu(const Stream& stream, Span<T> output, View<T> input) {
+    generic_op<T, elu_functor>(stream, output, input);
+}
+
+template <class T>
+void relu(const Stream& stream, Span<T> output, View<T> input, T slope) {
+    generic_op<T, relu_functor>(stream, output, input, slope);
+}
+
+template <class T>
+void clipped_relu(const Stream& stream, Span<T> output, View<T> input, T floor, T ceiling) {
+    CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceiling));
+    generic_op<T, clipped_relu_functor>(stream, output, input, floor, ceiling);
+}
+
+template <class T>
+void power(const Stream& stream, Span<T> output, View<T> input, T exp, T scale, T shift) {
+    CV_Assert(input.size() == output.size());
+
+    if (static_cast<float>(exp) == 1.0f) {
+        scale1_with_bias1(stream, output, input, scale, shift);
+        return;
+    }
+
+    generic_op<T, power_functor>(stream, output, input, exp, scale, shift);
+}
+
 #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void abs<__half>(const Stream& stream, Span<__half> output, View<__half> input);
+template void abs<__half>(const Stream& stream, Span<__half> output, View<__half> input);
+template void tanh<__half>(const Stream&, Span<__half>, View<__half>);
+template void swish<__half>(const Stream&, Span<__half>, View<__half>);
+template void mish<__half>(const Stream&, Span<__half>, View<__half>);
+template void sigmoid<__half>(const Stream&, Span<__half>, View<__half>);
+template void bnll<__half>(const Stream&, Span<__half>, View<__half>);
+template void elu<__half>(const Stream&, Span<__half>, View<__half>);
+template void relu<__half>(const Stream&, Span<__half>, View<__half>, __half);
+template void clipped_relu<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
+template void power<__half>(const Stream&, Span<__half>, View<__half>, __half, __half, __half);
 #endif
-    template void abs<float>(const Stream& stream, Span<float> output, View<float> input);
 
-    template <class T, std::size_t N>
-    void launch_vectorized_tanh(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
+template void abs<float>(const Stream& stream, Span<float> output, View<float> input);
+template void tanh<float>(const Stream&, Span<float>, View<float>);
+template void swish<float>(const Stream&, Span<float>, View<float>);
+template void mish<float>(const Stream&, Span<float>, View<float>);
+template void sigmoid<float>(const Stream&, Span<float>, View<float>);
+template void bnll<float>(const Stream&, Span<float>, View<float>);
+template void elu<float>(const Stream&, Span<float>, View<float>);
+template void relu<float>(const Stream&, Span<float>, View<float>, float);
+template void clipped_relu<float>(const Stream&, Span<float>, View<float>, float, float);
+template void power<float>(const Stream&, Span<float>, View<float>, float, float, float);
 
-        auto kernel = raw::tanh_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input);
-    }
-
-    template <class T>
-    void tanh(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(input.size() == output.size());
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_vectorized_tanh<T, 4>(stream, output, input);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_vectorized_tanh<T, 2>(stream, output, input);
-        } else {
-            launch_vectorized_tanh<T, 1>(stream, output, input);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void tanh<__half>(const Stream&, Span<__half>, View<__half>);
-#endif
-    template void tanh<float>(const Stream&, Span<float>, View<float>);
-
-    template <class T, std::size_t N>
-    void launch_vectorized_swish(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
-
-        auto kernel = raw::swish_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input);
-    }
-
-    template <class T>
-    void swish(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(input.size() == output.size());
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_vectorized_swish<T, 4>(stream, output, input);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_vectorized_swish<T, 2>(stream, output, input);
-        } else {
-            launch_vectorized_swish<T, 1>(stream, output, input);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void swish<__half>(const Stream&, Span<__half>, View<__half>);
-#endif
-    template void swish<float>(const Stream&, Span<float>, View<float>);
-
-    template <class T, std::size_t N>
-    void launch_vectorized_mish(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
-
-        auto kernel = raw::mish_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input);
-    }
-
-    template <class T>
-    void mish(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(input.size() == output.size());
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_vectorized_mish<T, 4>(stream, output, input);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_vectorized_mish<T, 2>(stream, output, input);
-        } else {
-            launch_vectorized_mish<T, 1>(stream, output, input);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void mish<__half>(const Stream&, Span<__half>, View<__half>);
-#endif
-    template void mish<float>(const Stream&, Span<float>, View<float>);
-
-    template <class T, std::size_t N>
-    void launch_vectorized_sigmoid(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
-
-        auto kernel = raw::sigmoid_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input);
-    }
-
-    template <class T>
-    void sigmoid(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(input.size() == output.size());
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_vectorized_sigmoid<T, 4>(stream, output, input);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_vectorized_sigmoid<T, 2>(stream, output, input);
-        } else {
-            launch_vectorized_sigmoid<T, 1>(stream, output, input);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void sigmoid<__half>(const Stream&, Span<__half>, View<__half>);
-#endif
-    template void sigmoid<float>(const Stream&, Span<float>, View<float>);
-
-    template <class T, std::size_t N>
-    void launch_vectorized_bnll(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
-
-        auto kernel = raw::bnll_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input);
-    }
-
-    template <class T>
-    void bnll(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(input.size() == output.size());
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_vectorized_bnll<T, 4>(stream, output, input);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_vectorized_bnll<T, 2>(stream, output, input);
-        } else {
-            launch_vectorized_bnll<T, 1>(stream, output, input);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void bnll<__half>(const Stream&, Span<__half>, View<__half>);
-#endif
-    template void bnll<float>(const Stream&, Span<float>, View<float>);
-
-    template <class T, std::size_t N>
-    void launch_vectorized_elu(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
-
-        auto kernel = raw::elu_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input);
-    }
-
-    template <class T>
-    void elu(const Stream& stream, Span<T> output, View<T> input) {
-        CV_Assert(input.size() == output.size());
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_vectorized_elu<T, 4>(stream, output, input);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_vectorized_elu<T, 2>(stream, output, input);
-        } else {
-            launch_vectorized_elu<T, 1>(stream, output, input);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void elu<__half>(const Stream&, Span<__half>, View<__half>);
-#endif
-    template void elu<float>(const Stream&, Span<float>, View<float>);
-
-    template <class T, std::size_t N>
-    void launch_vectorized_relu(const Stream& stream, Span<T> output, View<T> input, T slope) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
-
-        auto kernel = raw::relu_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input, slope);
-    }
-
-    template <class T>
-    void relu(const Stream& stream, Span<T> output, View<T> input, T slope) {
-        CV_Assert(input.size() == output.size());
-
-        if(is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_vectorized_relu<T, 4>(stream, output, input, slope);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_vectorized_relu<T, 2>(stream, output, input, slope);
-        } else {
-            launch_vectorized_relu<T, 1>(stream, output, input, slope);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void relu<__half>(const Stream&, Span<__half>, View<__half>, __half);
-#endif
-    template void relu<float>(const Stream&, Span<float>, View<float>, float);
-
-    template <class T, std::size_t N>
-    void launch_vectorized_clipped_relu(const Stream& stream, Span<T> output, View<T> input, T floor, T ceiling) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
-
-        auto kernel = raw::clipped_relu_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input, floor, ceiling);
-    }
-
-    template <class T>
-    void clipped_relu(const Stream& stream, Span<T> output, View<T> input, T floor, T ceiling) {
-        CV_Assert(input.size() == output.size());
-        CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceiling));
-
-        if(is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_vectorized_clipped_relu<T, 4>(stream, output, input, floor, ceiling);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_vectorized_clipped_relu<T, 2>(stream, output, input, floor, ceiling);
-        } else {
-            launch_vectorized_clipped_relu<T, 1>(stream, output, input, floor, ceiling);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void clipped_relu<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
-#endif
-    template void clipped_relu<float>(const Stream&, Span<float>, View<float>, float, float);
-
-    template <class T, std::size_t N>
-    void launch_vectorized_axiswise_relu(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> slope) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
-        CV_Assert(inner_size % N == 0);
-
-        auto kernel = raw::axiswise_relu_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input, inner_size, slope);
-    }
-
-    template <class T>
-    void axiswise_relu(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> slope) {
-        CV_Assert(input.size() == output.size());
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4) && inner_size % 4 == 0) {
-            launch_vectorized_axiswise_relu<T, 4>(stream, output, input, inner_size, slope);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2) && inner_size % 2 == 0) {
-            launch_vectorized_axiswise_relu<T, 2>(stream, output, input, inner_size, slope);
-        } else {
-            launch_vectorized_axiswise_relu<T, 1>(stream, output, input, inner_size, slope);
+template <class T, std::size_t N> static
+void launch_vectorized_axiswise_relu(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> slope) {
+    CV_Assert(is_fully_aligned<T>(output, N));
+    CV_Assert(is_fully_aligned<T>(input, N));
+    CV_Assert(inner_size % N == 0);
+
+    auto kernel = raw::axiswise_relu_vec<T, N>;
+    auto policy = make_policy(kernel, output.size() / N, 0, stream);
+    launch_kernel(kernel, policy, output, input, inner_size, slope);
+}
+
+template <class T>
+void axiswise_relu(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> slope) {
+    CV_Assert(input.size() == output.size());
+
+    if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4) && inner_size % 4 == 0) {
+        launch_vectorized_axiswise_relu<T, 4>(stream, output, input, inner_size, slope);
+    } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2) && inner_size % 2 == 0) {
+        launch_vectorized_axiswise_relu<T, 2>(stream, output, input, inner_size, slope);
+    } else {
+        launch_vectorized_axiswise_relu<T, 1>(stream, output, input, inner_size, slope);
     }
 }
 
 #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
     template void axiswise_relu<__half>(const Stream&, Span<__half>, View<__half>, std::size_t, View<__half>);
 #endif
     template void axiswise_relu<float>(const Stream&, Span<float>, View<float>, std::size_t, View<float>);
 
-    template <class T, std::size_t N>
-    void launch_vectorized_power(const Stream& stream, Span<T> output, View<T> input, T exp, T scale, T shift) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
-
-        auto kernel = raw::power_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input, exp, scale, shift);
-    }
-
-    template <class T>
-    void power(const Stream& stream, Span<T> output, View<T> input, T exp, T scale, T shift) {
-        CV_Assert(input.size() == output.size());
-
-        if (static_cast<float>(exp) == 1.0f) {
-            scale1_with_bias1(stream, output, input, scale, shift);
-            return;
-        }
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4) && output.size()) {
-            launch_vectorized_power<T, 4>(stream, output, input, exp, scale, shift);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2) && output.size()) {
-            launch_vectorized_power<T, 2>(stream, output, input, exp, scale, shift);
-        } else {
-            launch_vectorized_power<T, 1>(stream, output, input, exp, scale, shift);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void power<__half>(const Stream&, Span<__half>, View<__half>, __half, __half, __half);
-#endif
-    template void power<float>(const Stream&, Span<float>, View<float>, float, float, float);
-
 }}}} /* namespace cv::dnn::cuda4dnn::kernels */
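Every activation wrapper above now funnels into generic_op, so adding a new activation reduces to a functor plus a one-line wrapper; the 4/2/1 alignment dispatch and the vectorized kernel are shared. A minimal sketch of the pattern under that assumption (scaled_abs_functor and scaled_abs are hypothetical names, not part of this patch):

    // in functors.hpp: state lives in members, initialized from the trailing
    // arguments that generic_op forwards into the kernel
    template <class T>
    struct scaled_abs_functor {
        __device__ scaled_abs_functor(T scale_) : scale{scale_} { }
        __device__ T operator()(T value) {
            using csl::device::abs;
            return scale * abs(value);  // hypothetical op: scale * |x|
        }
        T scale;
    };

    // in activations.cu: the vectorization width is still chosen by generic_op
    template <class T>
    void scaled_abs(const Stream& stream, Span<T> output, View<T> input, T scale) {
        generic_op<T, scaled_abs_functor>(stream, output, input, scale);
    }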
diff --git a/modules/dnn/src/cuda/bias_activation.cu b/modules/dnn/src/cuda/bias_activation.cu
index 6a5229c660..0acc2ff54d 100644
--- a/modules/dnn/src/cuda/bias_activation.cu
+++ b/modules/dnn/src/cuda/bias_activation.cu
@@ -5,8 +5,8 @@
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 
+#include "functors.hpp"
 #include "types.hpp"
-#include "math.hpp"
 #include "vector_traits.hpp"
 #include "grid_stride_range.hpp"
 #include "execution.hpp"
@@ -20,331 +20,103 @@ using namespace cv::dnn::cuda4dnn::csl::device;
 
 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
 
 namespace raw {
-
-    template <class T, std::size_t N>
-    __global__ void biasN_relu_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias, T slope) {
+    template <class T, class Functor, std::size_t N, class ...FunctorArgs>
+    __global__ void biasN_generic_op_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias, FunctorArgs ...functorArgs) {
         using vector_type = get_vector_type_t<T, N>;
 
         auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
 
+        Functor functor(functorArgs...);
+
         inner_size /= vector_type::size();
         for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
            const index_type bias_idx = (i / inner_size) % static_cast<size_type>(bias.size());
 
            vector_type vec;
            v_load(vec, inplace_output_vPtr[i]);
-           for(int j = 0; j < vec.size(); j++) {
-               vec.data[j] += bias[bias_idx];
-               vec.data[j] = vec.data[j] >= T(0) ? vec.data[j] : slope * vec.data[j];
-           }
+           for(int j = 0; j < vec.size(); j++)
+               vec.data[j] = functor(vec.data[j] + bias[bias_idx]);
            v_store(inplace_output_vPtr[i], vec);
         }
     }
+} /* namespace raw */
 
-    template <class T, std::size_t N>
-    __global__ void biasN_clipped_relu_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias, T floor, T ceil) {
-        using vector_type = get_vector_type_t<T, N>;
-
-        auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
-
-        inner_size /= vector_type::size();
-        for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
-            const index_type bias_idx = (i / inner_size) % static_cast<size_type>(bias.size());
-
-            vector_type vec;
-            v_load(vec, inplace_output_vPtr[i]);
-            for(int j = 0; j < vec.size(); j++) {
-                using device::clamp;
-                vec.data[j] = clamp(vec.data[j] + bias[bias_idx], floor, ceil);
-            }
-            v_store(inplace_output_vPtr[i], vec);
-        }
-    }
-
-    template <class T, std::size_t N>
-    __global__ void biasN_power_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias, T power) {
-        using vector_type = get_vector_type_t<T, N>;
-
-        auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
-
-        inner_size /= vector_type::size();
-        for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
-            const index_type bias_idx = (i / inner_size) % static_cast<size_type>(bias.size());
-
-            vector_type vec;
-            v_load(vec, inplace_output_vPtr[i]);
-            for(int j = 0; j < vec.size(); j++) {
-                using device::pow;
-                vec.data[j] = pow(vec.data[j] + bias[bias_idx], power);
-            }
-            v_store(inplace_output_vPtr[i], vec);
-        }
-    }
-
-    template <class T, std::size_t N>
-    __global__ void biasN_tanh_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias) {
-        using vector_type = get_vector_type_t<T, N>;
-
-        auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
-
-        inner_size /= vector_type::size();
-        for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
-            const index_type bias_idx = (i / inner_size) % static_cast<size_type>(bias.size());
-
-            vector_type vec;
-            v_load(vec, inplace_output_vPtr[i]);
-            for(int j = 0; j < vec.size(); j++) {
-                using device::tanh;
-                vec.data[j] = tanh(vec.data[j] + bias[bias_idx]);
-            }
-            v_store(inplace_output_vPtr[i], vec);
-        }
-    }
-
-    template <class T, std::size_t N>
-    __global__ void biasN_sigmoid_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias) {
-        using vector_type = get_vector_type_t<T, N>;
-
-        auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
-
-        inner_size /= vector_type::size();
-        for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
-            const index_type bias_idx = (i / inner_size) % static_cast<size_type>(bias.size());
-
-            vector_type vec;
-            v_load(vec, inplace_output_vPtr[i]);
-            for(int j = 0; j < vec.size(); j++) {
-                using device::sigmoid;
-                vec.data[j] = sigmoid(vec.data[j] + bias[bias_idx]);
-            }
-            v_store(inplace_output_vPtr[i], vec);
-        }
-    }
-
-    template <class T, std::size_t N>
-    __global__ void biasN_swish_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias) {
-        using vector_type = get_vector_type_t<T, N>;
-
-        auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
-
-        inner_size /= vector_type::size();
-        for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
-            const index_type bias_idx = (i / inner_size) % static_cast<size_type>(bias.size());
-
-            vector_type vec;
-            v_load(vec, inplace_output_vPtr[i]);
-            for(int j = 0; j < vec.size(); j++) {
-                using device::sigmoid;
-                vec.data[j] += bias[bias_idx];
-                vec.data[j] = vec.data[j] * sigmoid(vec.data[j]);
-            }
-            v_store(inplace_output_vPtr[i], vec);
-        }
-    }
-
-    template <class T, std::size_t N>
-    __global__ void biasN_mish_inplace_vec(Span<T> inplace_output, size_type inner_size, View<T> bias) {
-        using vector_type = get_vector_type_t<T, N>;
-
-        auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
-
-        inner_size /= vector_type::size();
-        for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
-            const index_type bias_idx = (i / inner_size) % static_cast<size_type>(bias.size());
-
-            vector_type vec;
-            v_load(vec, inplace_output_vPtr[i]);
-            for(int j = 0; j < vec.size(); j++) {
-                using device::tanh;
-                using device::log1pexp;
-                vec.data[j] += bias[bias_idx];
-                vec.data[j] = vec.data[j] * tanh(log1pexp(vec.data[j]));
-            }
-            v_store(inplace_output_vPtr[i], vec);
-        }
-    }
-}
-
-template <class T, std::size_t N> static
-void launch_biasN_relu_inplace_vec_kernel(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T slope) {
+template <class T, template <class> class Activation, std::size_t N, class ...ActivationArgs> static
+void launch_vectorized_biasN_generic_op_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, ActivationArgs ...activationArgs) {
+    CV_Assert(inplace_output.size() % inner_size == 0);
+    CV_Assert(inplace_output.size() % bias.size() == 0);
     CV_Assert(is_fully_aligned<T>(inplace_output, N));
     CV_Assert(inner_size % N == 0);
 
-    auto kernel = raw::biasN_relu_inplace_vec<T, N>;
+    auto kernel = raw::biasN_generic_op_inplace_vec<T, Activation<T>, N, ActivationArgs...>;
     auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
-    launch_kernel(kernel, policy, inplace_output, inner_size, bias, slope);
+    launch_kernel(kernel, policy, inplace_output, inner_size, bias, activationArgs...);
+}
+
+template <class T, template <class> class Activation, class ...ActivationArgs> static
+void biasN_generic_op_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, ActivationArgs ...activationArgs) {
+    if (is_fully_aligned<T>(inplace_output, 4) && inner_size % 4 == 0) {
+        launch_vectorized_biasN_generic_op_inplace<T, Activation, 4>(stream, inplace_output, inner_size, bias, activationArgs...);
+    } else if (is_fully_aligned<T>(inplace_output, 2) && inner_size % 2 == 0) {
+        launch_vectorized_biasN_generic_op_inplace<T, Activation, 2>(stream, inplace_output, inner_size, bias, activationArgs...);
+    } else {
+        launch_vectorized_biasN_generic_op_inplace<T, Activation, 1>(stream, inplace_output, inner_size, bias, activationArgs...);
+    }
 }
 
 template <class T>
 void biasN_relu_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T slope) {
-    if (is_fully_aligned<T>(inplace_output, 4) && inner_size % 4 == 0) {
-        launch_biasN_relu_inplace_vec_kernel<T, 4>(stream, inplace_output, inner_size, bias, slope);
-    } else if (is_fully_aligned<T>(inplace_output, 2) && inner_size % 2 == 0) {
-        launch_biasN_relu_inplace_vec_kernel<T, 2>(stream, inplace_output, inner_size, bias, slope);
-    } else {
-        launch_biasN_relu_inplace_vec_kernel<T, 1>(stream, inplace_output, inner_size, bias, slope);
-    }
-}
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-template void biasN_relu_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, __half);
-#endif
-template void biasN_relu_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, float);
-
-template <class T, std::size_t N> static
-void launch_biasN_clipped_relu_inplace_vec_kernel(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T floor, T ceil) {
-    CV_Assert(is_fully_aligned<T>(inplace_output, N));
-    CV_Assert(inner_size % N == 0);
-
-    auto kernel = raw::biasN_clipped_relu_inplace_vec<T, N>;
-    auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
-    launch_kernel(kernel, policy, inplace_output, inner_size, bias, floor, ceil);
+    biasN_generic_op_inplace<T, relu_functor>(stream, inplace_output, inner_size, bias, slope);
 }
 
 template <class T>
 void biasN_clipped_relu_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T floor, T ceil) {
-    if (is_fully_aligned<T>(inplace_output, 4) && inner_size % 4 == 0) {
-        launch_biasN_clipped_relu_inplace_vec_kernel<T, 4>(stream, inplace_output, inner_size, bias, floor, ceil);
-    } else if (is_fully_aligned<T>(inplace_output, 2) && inner_size % 2 == 0) {
-        launch_biasN_clipped_relu_inplace_vec_kernel<T, 2>(stream, inplace_output, inner_size, bias, floor, ceil);
-    } else {
-        launch_biasN_clipped_relu_inplace_vec_kernel<T, 1>(stream, inplace_output, inner_size, bias, floor, ceil);
-    }
-}
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-template void biasN_clipped_relu_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, __half, __half);
-#endif
-template void biasN_clipped_relu_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, float, float);
-
-template <class T, std::size_t N> static
-void launch_biasN_power_inplace_vec_kernel(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T power) {
-    CV_Assert(is_fully_aligned<T>(inplace_output, N));
-    CV_Assert(inner_size % N == 0);
-
-    auto kernel = raw::biasN_power_inplace_vec<T, N>;
-    auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
-    launch_kernel(kernel, policy, inplace_output, inner_size, bias, power);
+    CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceil));
+    biasN_generic_op_inplace<T, clipped_relu_functor>(stream, inplace_output, inner_size, bias, floor, ceil);
 }
 
 template <class T>
-void biasN_power_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T power) {
-    if (is_fully_aligned<T>(inplace_output, 4) && inner_size % 4 == 0) {
-        launch_biasN_power_inplace_vec_kernel<T, 4>(stream, inplace_output, inner_size, bias, power);
-    } else if (is_fully_aligned<T>(inplace_output, 2) && inner_size % 2 == 0) {
-        launch_biasN_power_inplace_vec_kernel<T, 2>(stream, inplace_output, inner_size, bias, power);
-    } else {
-        launch_biasN_power_inplace_vec_kernel<T, 1>(stream, inplace_output, inner_size, bias, power);
-    }
-}
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-template void biasN_power_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, __half);
-#endif
-template void biasN_power_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, float);
-
-template <class T, std::size_t N> static
-void launch_biasN_tanh_inplace_vec_kernel(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-    CV_Assert(is_fully_aligned<T>(inplace_output, N));
-    CV_Assert(inner_size % N == 0);
-
-    auto kernel = raw::biasN_tanh_inplace_vec<T, N>;
-    auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
-    launch_kernel(kernel, policy, inplace_output, inner_size, bias);
+void biasN_power_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias, T power, T scale, T shift) {
+    biasN_generic_op_inplace<T, power_functor>(stream, inplace_output, inner_size, bias, power, scale, shift);
 }
 
 template <class T>
 void biasN_tanh_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-    if (is_fully_aligned<T>(inplace_output, 4) && inner_size % 4 == 0) {
-        launch_biasN_tanh_inplace_vec_kernel<T, 4>(stream, inplace_output, inner_size, bias);
-    } else if (is_fully_aligned<T>(inplace_output, 2) && inner_size % 2 == 0) {
-        launch_biasN_tanh_inplace_vec_kernel<T, 2>(stream, inplace_output, inner_size, bias);
-    } else {
-        launch_biasN_tanh_inplace_vec_kernel<T, 1>(stream, inplace_output, inner_size, bias);
-    }
-}
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-template void biasN_tanh_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
-#endif
-template void biasN_tanh_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
-
-template <class T, std::size_t N> static
-void launch_biasN_sigmoid_inplace_vec_kernel(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-    CV_Assert(is_fully_aligned<T>(inplace_output, N));
-    CV_Assert(inner_size % N == 0);
-
-    auto kernel = raw::biasN_sigmoid_inplace_vec<T, N>;
-    auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
-    launch_kernel(kernel, policy, inplace_output, inner_size, bias);
+    biasN_generic_op_inplace<T, tanh_functor>(stream, inplace_output, inner_size, bias);
 }
 
 template <class T>
 void biasN_sigmoid_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-    if (is_fully_aligned<T>(inplace_output, 4) && inner_size % 4 == 0) {
-        launch_biasN_sigmoid_inplace_vec_kernel<T, 4>(stream, inplace_output, inner_size, bias);
-    } else if (is_fully_aligned<T>(inplace_output, 2) && inner_size % 2 == 0) {
-        launch_biasN_sigmoid_inplace_vec_kernel<T, 2>(stream, inplace_output, inner_size, bias);
-    } else {
-        launch_biasN_sigmoid_inplace_vec_kernel<T, 1>(stream, inplace_output, inner_size, bias);
-    }
-}
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-template void biasN_sigmoid_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
-#endif
-template void biasN_sigmoid_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
-
-template <class T, std::size_t N> static
-void launch_biasN_swish_inplace_vec_kernel(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-    CV_Assert(is_fully_aligned<T>(inplace_output, N));
-    CV_Assert(inner_size % N == 0);
-
-    auto kernel = raw::biasN_swish_inplace_vec<T, N>;
-    auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
-    launch_kernel(kernel, policy, inplace_output, inner_size, bias);
+    biasN_generic_op_inplace<T, sigmoid_functor>(stream, inplace_output, inner_size, bias);
 }
 
 template <class T>
 void biasN_swish_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-    if (is_fully_aligned<T>(inplace_output, 4) && inner_size % 4 == 0) {
-        launch_biasN_swish_inplace_vec_kernel<T, 4>(stream, inplace_output, inner_size, bias);
-    } else if (is_fully_aligned<T>(inplace_output, 2) && inner_size % 2 == 0) {
-        launch_biasN_swish_inplace_vec_kernel<T, 2>(stream, inplace_output, inner_size, bias);
-    } else {
-        launch_biasN_swish_inplace_vec_kernel<T, 1>(stream, inplace_output, inner_size, bias);
-    }
-}
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-template void biasN_swish_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
-#endif
-template void biasN_swish_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
-
-template <class T, std::size_t N> static
-void launch_biasN_mish_inplace_vec_kernel(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-    CV_Assert(is_fully_aligned<T>(inplace_output, N));
-    CV_Assert(inner_size % N == 0);
-
-    auto kernel = raw::biasN_mish_inplace_vec<T, N>;
-    auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
-    launch_kernel(kernel, policy, inplace_output, inner_size, bias);
+    biasN_generic_op_inplace<T, swish_functor>(stream, inplace_output, inner_size, bias);
 }
 
 template <class T>
 void biasN_mish_inplace(const Stream& stream, Span<T> inplace_output, std::size_t inner_size, View<T> bias) {
-    if (is_fully_aligned<T>(inplace_output, 4) && inner_size % 4 == 0) {
-        launch_biasN_mish_inplace_vec_kernel<T, 4>(stream, inplace_output, inner_size, bias);
-    } else if (is_fully_aligned<T>(inplace_output, 2) && inner_size % 2 == 0) {
-        launch_biasN_mish_inplace_vec_kernel<T, 2>(stream, inplace_output, inner_size, bias);
-    } else {
-        launch_biasN_mish_inplace_vec_kernel<T, 1>(stream, inplace_output, inner_size, bias);
-    }
+    biasN_generic_op_inplace<T, mish_functor>(stream, inplace_output, inner_size, bias);
 }
 
 #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
+template void biasN_relu_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, __half);
+template void biasN_clipped_relu_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, __half, __half);
+template void biasN_power_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>, __half, __half, __half);
+template void biasN_tanh_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
+template void biasN_sigmoid_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
+template void biasN_swish_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
 template void biasN_mish_inplace<__half>(const Stream&, Span<__half>, std::size_t, View<__half>);
 #endif
+
+template void biasN_relu_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, float);
+template void biasN_clipped_relu_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, float, float);
+template void biasN_power_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>, float, float, float);
+template void biasN_tanh_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
+template void biasN_sigmoid_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
+template void biasN_swish_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
 template void biasN_mish_inplace<float>(const Stream&, Span<float>, std::size_t, View<float>);
 
 }}}} /* namespace cv::dnn::cuda4dnn::kernels */
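The fused in-place kernel now computes functor(value + bias[bias_idx]) per element, so the extended biasN_power_inplace evaluates pow(shift + scale * (x + bias[c]), exp), matching the standalone power kernel in activations.cu. A hypothetical call site (buffer names illustrative; assumes an NCHW tensor with one bias per channel, so inner_size is H * W):

    // squares each biased element in place:
    // y = pow(0.0f + 1.0f * (x + bias[c]), 2.0f)
    kernels::biasN_power_inplace<float>(stream, inplace_output, H * W, bias,
                                        /*exp=*/2.0f, /*scale=*/1.0f, /*shift=*/0.0f);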
diff --git a/modules/dnn/src/cuda/eltwise_ops.cu b/modules/dnn/src/cuda/eltwise_ops.cu
index 521bb4351b..a7d06e63a1 100644
--- a/modules/dnn/src/cuda/eltwise_ops.cu
+++ b/modules/dnn/src/cuda/eltwise_ops.cu
@@ -5,7 +5,7 @@
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 
-#include "math.hpp"
+#include "functors.hpp"
 #include "grid_stride_range.hpp"
 #include "execution.hpp"
 #include "vector_traits.hpp"
@@ -20,263 +20,91 @@ using namespace cv::dnn::cuda4dnn::csl::device;
 
 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
 
-    namespace raw {
-        template <class T, std::size_t N>
-        __global__ void eltwise_max_2_vec(Span<T> output, View<T> x, View<T> y) {
-            using vector_type = get_vector_type_t<T, N>;
+namespace raw {
+    template <class T, class Functor, std::size_t N, class ...FunctorArgs>
+    __global__ void eltwise_op_vec(Span<T> output, View<T> x, View<T> y, FunctorArgs ...functorArgs) {
+        using vector_type = get_vector_type_t<T, N>;
 
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto x_vPtr = vector_type::get_pointer(x.data());
-            auto y_vPtr = vector_type::get_pointer(y.data());
+        auto output_vPtr = vector_type::get_pointer(output.data());
+        auto x_vPtr = vector_type::get_pointer(x.data());
+        auto y_vPtr = vector_type::get_pointer(y.data());
 
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec_x, vec_y;
-                v_load(vec_x, x_vPtr[i]);
-                v_load(vec_y, y_vPtr[i]);
+        Functor functor(functorArgs...);
 
-                for (int j = 0; j < vector_type::size(); j++) {
-                    using device::max;
-                    vec_x.data[j] = max(vec_x.data[j], vec_y.data[j]);
-                }
-
-                v_store(output_vPtr[i], vec_x);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void eltwise_sum_2_vec(Span<T> output, View<T> x, View<T> y) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto x_vPtr = vector_type::get_pointer(x.data());
-            auto y_vPtr = vector_type::get_pointer(y.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec_x, vec_y;
-                v_load(vec_x, x_vPtr[i]);
-                v_load(vec_y, y_vPtr[i]);
-
-                for (int j = 0; j < vector_type::size(); j++)
-                    vec_x.data[j] = vec_x.data[j] + vec_y.data[j];
-
-                v_store(output_vPtr[i], vec_x);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void eltwise_sum_coeff_2_vec(Span<T> output, T coeff_x, View<T> x, T coeff_y, View<T> y) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto x_vPtr = vector_type::get_pointer(x.data());
-            auto y_vPtr = vector_type::get_pointer(y.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec_x, vec_y;
-                v_load(vec_x, x_vPtr[i]);
-                v_load(vec_y, y_vPtr[i]);
-
-                for (int j = 0; j < vector_type::size(); j++)
-                    vec_x.data[j] = coeff_x * vec_x.data[j] + coeff_y * vec_y.data[j];
-
-                v_store(output_vPtr[i], vec_x);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void eltwise_prod_2_vec(Span<T> output, View<T> x, View<T> y) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto x_vPtr = vector_type::get_pointer(x.data());
-            auto y_vPtr = vector_type::get_pointer(y.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec_x, vec_y;
-                v_load(vec_x, x_vPtr[i]);
-                v_load(vec_y, y_vPtr[i]);
-
-                for (int j = 0; j < vector_type::size(); j++)
-                    vec_x.data[j] = vec_x.data[j] * vec_y.data[j];
-
-                v_store(output_vPtr[i], vec_x);
-            }
-        }
-
-        template <class T, std::size_t N>
-        __global__ void eltwise_div_2_vec(Span<T> output, View<T> x, View<T> y) {
-            using vector_type = get_vector_type_t<T, N>;
-
-            auto output_vPtr = vector_type::get_pointer(output.data());
-            auto x_vPtr = vector_type::get_pointer(x.data());
-            auto y_vPtr = vector_type::get_pointer(y.data());
-
-            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-                vector_type vec_x, vec_y;
-                v_load(vec_x, x_vPtr[i]);
-                v_load(vec_y, y_vPtr[i]);
-
-                for (int j = 0; j < vector_type::size(); j++)
-                    vec_x.data[j] = vec_x.data[j] / vec_y.data[j];
-
-                v_store(output_vPtr[i], vec_x);
-            }
+        for (auto i : grid_stride_range(output.size() / vector_type::size())) {
+            vector_type vec_x, vec_y;
+            v_load(vec_x, x_vPtr[i]);
+            v_load(vec_y, y_vPtr[i]);
+            for (int j = 0; j < vector_type::size(); j++)
+                vec_x.data[j] = functor(vec_x.data[j], vec_y.data[j]);
+            v_store(output_vPtr[i], vec_x);
         }
     }
+}
 
-    template <class T, std::size_t N>
-    void launch_vectorized_eltwise_max_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(x, N));
-        CV_Assert(is_fully_aligned<T>(y, N));
+template <class T, template <class> class EltwiseOp, std::size_t N, class ...EltwiseOpArgs> static
+void launch_vectorized_eltwise_op(const Stream& stream, Span<T> output, View<T> x, View<T> y, EltwiseOpArgs ...eltwiseOpArgs) {
+    CV_Assert(x.size() == y.size());
+    CV_Assert(x.size() == output.size());
+    CV_Assert(is_fully_aligned<T>(output, N));
+    CV_Assert(is_fully_aligned<T>(x, N));
+    CV_Assert(is_fully_aligned<T>(y, N));
 
-        auto kernel = raw::eltwise_max_2_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, x, y);
+    auto kernel = raw::eltwise_op_vec<T, EltwiseOp<T>, N, EltwiseOpArgs...>;
+    auto policy = make_policy(kernel, output.size() / N, 0, stream);
+    launch_kernel(kernel, policy, output, x, y, eltwiseOpArgs...);
+}
+
+template <class T, template <class> class EltwiseOp, class ...EltwiseOpArgs> static
+void eltwise_op(const Stream& stream, Span<T> output, View<T> x, View<T> y, EltwiseOpArgs ...eltwiseOpArgs) {
+    CV_Assert(x.size() == y.size());
+    CV_Assert(x.size() == output.size());
+
+    if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
+        launch_vectorized_eltwise_op<T, EltwiseOp, 4>(stream, output, x, y, eltwiseOpArgs...);
+    } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
+        launch_vectorized_eltwise_op<T, EltwiseOp, 2>(stream, output, x, y, eltwiseOpArgs...);
+    } else {
+        launch_vectorized_eltwise_op<T, EltwiseOp, 1>(stream, output, x, y, eltwiseOpArgs...);
     }
+}
 
-    template <class T>
-    void eltwise_max_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-        CV_Assert(x.size() == y.size());
-        CV_Assert(x.size() == output.size());
+template <class T>
+void eltwise_max_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
+    eltwise_op<T, max_functor>(stream, output, x, y);
+}
 
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
-            launch_vectorized_eltwise_max_2<T, 4>(stream, output, x, y);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
-            launch_vectorized_eltwise_max_2<T, 2>(stream, output, x, y);
-        } else {
-            launch_vectorized_eltwise_max_2<T, 1>(stream, output, x, y);
-        }
-    }
+template <class T>
+void eltwise_sum_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
+    eltwise_op<T, sum_functor>(stream, output, x, y);
+}
 
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void eltwise_max_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y);
-#endif
-    template void eltwise_max_2(const Stream& stream, Span<float> output, View<float> x, View<float> y);
+template <class T>
+void eltwise_sum_coeff_2(const Stream& stream, Span<T> output, T coeff_x, View<T> x, T coeff_y, View<T> y) {
+    eltwise_op<T, scaled_sum_functor>(stream, output, x, y, coeff_x, coeff_y);
+}
 
-    template <class T, std::size_t N>
-    void launch_vectorized_eltwise_sum_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(x, N));
-        CV_Assert(is_fully_aligned<T>(y, N));
+template <class T>
+void eltwise_prod_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
+    eltwise_op<T, product_functor>(stream, output, x, y);
+}
 
-        auto kernel = raw::eltwise_sum_2_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, x, y);
-    }
-
-    template <class T>
-    void eltwise_sum_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-        CV_Assert(x.size() == y.size());
-        CV_Assert(x.size() == output.size());
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
-            launch_vectorized_eltwise_sum_2<T, 4>(stream, output, x, y);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
-            launch_vectorized_eltwise_sum_2<T, 2>(stream, output, x, y);
-        } else {
-            launch_vectorized_eltwise_sum_2<T, 1>(stream, output, x, y);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void eltwise_sum_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y);
-#endif
-    template void eltwise_sum_2(const Stream& stream, Span<float> output, View<float> x, View<float> y);
-
-    template <class T, std::size_t N>
-    void launch_vectorized_eltwise_sum_coeff_2(const Stream& stream, Span<T> output, T coeff_x, View<T> x, T coeff_y, View<T> y) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(x, N));
-        CV_Assert(is_fully_aligned<T>(y, N));
-
-        auto kernel = raw::eltwise_sum_coeff_2_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, coeff_x, x, coeff_y, y);
-    }
-
-    template <class T>
-    void eltwise_sum_coeff_2(const Stream& stream, Span<T> output, T coeff_x, View<T> x, T coeff_y, View<T> y) {
-        CV_Assert(x.size() == y.size());
-        CV_Assert(x.size() == output.size());
-
-        if (static_cast<float>(coeff_x) == 1.0f && static_cast<float>(coeff_y) == 1.0f) {
-            eltwise_sum_2(stream, output, x, y);
-            return;
-        }
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
-            launch_vectorized_eltwise_sum_coeff_2<T, 4>(stream, output, coeff_x, x, coeff_y, y);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
-            launch_vectorized_eltwise_sum_coeff_2<T, 2>(stream, output, coeff_x, x, coeff_y, y);
-        } else {
-            launch_vectorized_eltwise_sum_coeff_2<T, 1>(stream, output, coeff_x, x, coeff_y, y);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void eltwise_sum_coeff_2(const Stream&, Span<__half>, __half, View<__half>, __half, View<__half>);
-#endif
-    template void eltwise_sum_coeff_2(const Stream&, Span<float>, float, View<float>, float, View<float>);
-
-    template <class T, std::size_t N>
-    void launch_vectorized_eltwise_prod_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(x, N));
-        CV_Assert(is_fully_aligned<T>(y, N));
-
-        auto kernel = raw::eltwise_prod_2_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, x, y);
-    }
-
-    template <class T>
-    void eltwise_prod_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-        CV_Assert(x.size() == y.size());
-        CV_Assert(x.size() == output.size());
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
-            launch_vectorized_eltwise_prod_2<T, 4>(stream, output, x, y);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
-            launch_vectorized_eltwise_prod_2<T, 2>(stream, output, x, y);
-        } else {
-            launch_vectorized_eltwise_prod_2<T, 1>(stream, output, x, y);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void eltwise_prod_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y);
-#endif
-    template void eltwise_prod_2(const Stream& stream, Span<float> output, View<float> x, View<float> y);
-
-    template <class T, std::size_t N>
-    void launch_vectorized_eltwise_div_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(x, N));
-        CV_Assert(is_fully_aligned<T>(y, N));
-
-        auto kernel = raw::eltwise_div_2_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, x, y);
-    }
-
-    template <class T>
-    void eltwise_div_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
-        CV_Assert(x.size() == y.size());
-        CV_Assert(x.size() == output.size());
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
-            launch_vectorized_eltwise_div_2<T, 4>(stream, output, x, y);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
-            launch_vectorized_eltwise_div_2<T, 2>(stream, output, x, y);
-        } else {
-            launch_vectorized_eltwise_div_2<T, 1>(stream, output, x, y);
-        }
-    }
+template <class T>
+void eltwise_div_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
+    eltwise_op<T, div_functor>(stream, output, x, y);
+}
 
 #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
     template void eltwise_div_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y);
+    template void eltwise_prod_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y);
+    template void eltwise_sum_coeff_2(const Stream&, Span<__half>, __half, View<__half>, __half, View<__half>);
+    template void eltwise_sum_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y);
+    template void eltwise_max_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y);
 #endif
     template void eltwise_div_2(const Stream& stream, Span<float> output, View<float> x, View<float> y);
+    template void eltwise_prod_2(const Stream& stream, Span<float> output, View<float> x, View<float> y);
+    template void eltwise_sum_coeff_2(const Stream&, Span<float>, float, View<float>, float, View<float>);
+    template void eltwise_sum_2(const Stream& stream, Span<float> output, View<float> x, View<float> y);
+    template void eltwise_max_2(const Stream& stream, Span<float> output, View<float> x, View<float> y);
 
 }}}} /* namespace cv::dnn::cuda4dnn::kernels */
diff --git a/modules/dnn/src/cuda/execution.hpp b/modules/dnn/src/cuda/execution.hpp
index 57d1e302b9..27b86ef281 100644
--- a/modules/dnn/src/cuda/execution.hpp
+++ b/modules/dnn/src/cuda/execution.hpp
@@ -63,17 +63,17 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
 
     template <class Kernel, typename ...Args> inline
     void launch_kernel(Kernel kernel, Args ...args) {
         auto policy = make_policy(kernel);
-        kernel <<<policy.grid, policy.block>>> (std::forward<Args>(args)...);
+        kernel <<<policy.grid, policy.block>>> (args...);
     }
 
     template <class Kernel, typename ...Args> inline
     void launch_kernel(Kernel kernel, dim3 grid, dim3 block, Args ...args) {
-        kernel <<<grid, block>>> (std::forward<Args>(args)...);
+        kernel <<<grid, block>>> (args...);
     }
 
     template <class Kernel, typename ...Args> inline
     void launch_kernel(Kernel kernel, execution_policy policy, Args ...args) {
-        kernel <<<policy.grid, policy.block, policy.sharedMem, policy.stream>>> (std::forward<Args>(args)...);
+        kernel <<<policy.grid, policy.block, policy.sharedMem, policy.stream>>> (args...);
     }
 
 }}}} /* namespace cv::dnn::cuda4dnn::csl */
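Dropping std::forward in execution.hpp is deliberate rather than cosmetic: the launch_kernel parameter packs are taken by value, and the functor-argument packs introduced by this patch are plain copies all the way down to the kernel, so forwarding added nothing. A sketch of how one call flows through the new layers (values illustrative):

    relu<float>(stream, output, input, 0.1f);                       // slope = 0.1f
    // -> generic_op<float, relu_functor>(stream, output, input, 0.1f)
    // -> launch_vectorized_generic_op<float, relu_functor, 4>(...) // width from alignment
    // -> launch_kernel(kernel, policy, output, input, 0.1f)        // kernel <<<...>>> (args...)
    // -> raw::generic_op_vec<float, relu_functor<float>, 4, float>(output, input, 0.1f)
    //    which constructs relu_functor<float> functor(0.1f) once per thread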
diff --git a/modules/dnn/src/cuda/functors.hpp b/modules/dnn/src/cuda/functors.hpp
new file mode 100644
index 0000000000..c35a85437c
--- /dev/null
+++ b/modules/dnn/src/cuda/functors.hpp
@@ -0,0 +1,139 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#ifndef OPENCV_DNN_SRC_CUDA_FUNCTORS_HPP
+#define OPENCV_DNN_SRC_CUDA_FUNCTORS_HPP
+
+#include <cuda_runtime.h>
+
+#include "math.hpp"
+
+namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
+
+template <class T>
+struct abs_functor {
+    __device__ T operator()(T value) {
+        using csl::device::abs;
+        return abs(value);
+    }
+};
+
+template <class T>
+struct tanh_functor {
+    __device__ T operator()(T value) {
+        using csl::device::tanh;
+        return tanh(value);
+    }
+};
+
+template <class T>
+struct swish_functor {
+    __device__ T operator()(T value) {
+        using csl::device::sigmoid;
+        return value * sigmoid(value);
+    }
+};
+
+template <class T>
+struct mish_functor {
+    __device__ T operator()(T value) {
+        using csl::device::tanh;
+        using csl::device::log1pexp;
+        return value * tanh(log1pexp(value));
+    }
+};
+
+template <class T>
+struct sigmoid_functor {
+    __device__ T operator()(T value) {
+        using csl::device::sigmoid;
+        return sigmoid(value);
+    }
+};
+
+template <class T>
+struct bnll_functor {
+    __device__ T operator()(T value) {
+        using csl::device::log1pexp;
+        return value > T(0) ? value + log1pexp(-value) : log1pexp(value);
+    }
+};
+
+template <class T>
+struct elu_functor {
+    __device__ T operator()(T value) {
+        using csl::device::expm1;
+        return value >= T(0) ? value : expm1(value);
+    }
+};
+
+template <class T>
+struct relu_functor {
+    __device__ relu_functor(T slope_) : slope{slope_} { }
+    __device__ T operator()(T value) {
+        using csl::device::log1pexp;
+        return value >= T(0) ? value : slope * value;
+    }
+
+    T slope;
+};
+
+template <class T>
+struct clipped_relu_functor {
+    __device__ clipped_relu_functor(T floor_, T ceiling_) : floor{floor_}, ceiling{ceiling_} { }
+    __device__ T operator()(T value) {
+        using csl::device::clamp;
+        return clamp(value, floor, ceiling);
+    }
+
+    T floor, ceiling;
+};
+
+template <class T>
+struct power_functor {
+    __device__ power_functor(T exp_, T scale_, T shift_) : exp{exp_}, scale{scale_}, shift{shift_} { }
+    __device__ T operator()(T value) {
+        using csl::device::pow;
+        return pow(shift + scale * value, exp);
+    }
+
+    T exp, scale, shift;
+};
+
+template <class T>
+struct max_functor {
+    __device__ T operator()(T x, T y) {
+        using csl::device::max;
+        return max(x, y);
+    }
+};
+
+template <class T>
+struct sum_functor {
+    __device__ T operator()(T x, T y) { return x + y; }
+};
+
+template <class T>
+struct scaled_sum_functor {
+    __device__ scaled_sum_functor(T scale_x_, T scale_y_)
+        : scale_x{scale_x_}, scale_y{scale_y_} { }
+
+    __device__ T operator()(T x, T y) { return scale_x * x + scale_y * y; }
+
+    T scale_x, scale_y;
+};
+
+template <class T>
+struct product_functor {
+    __device__ T operator()(T x, T y) { return x * y; }
+};
+
+template <class T>
+struct div_functor {
+    __device__ T operator()(T x, T y) { return x / y; }
+};
+
+}}}} /* namespace cv::dnn::cuda4dnn::kernels */
+
+#endif /* OPENCV_DNN_SRC_CUDA_FUNCTORS_HPP */
\ No newline at end of file
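One detail in the new file worth a note: `bnll_functor` branches on the sign of its argument because evaluating log(1 + e^x) directly overflows once e^x saturates, while the rearrangement x + log(1 + e^-x) for positive x keeps the exponent non-positive. A quick host-side sanity check of that identity in plain C++ (independent of the `csl::device` math helpers):

    #include <cmath>
    #include <cassert>

    // log(1 + e^x) computed the way bnll_functor does: for x > 0, rewrite as
    // x + log1p(e^-x) so that exp() never overflows.
    double bnll(double x) {
        return x > 0 ? x + std::log1p(std::exp(-x)) : std::log1p(std::exp(x));
    }

    int main() {
        assert(std::abs(bnll(0.0) - std::log(2.0)) < 1e-12);
        assert(std::isfinite(bnll(1000.0)));  // naive log(1 + exp(1000)) would return inf
        assert(std::abs(bnll(1000.0) - 1000.0) < 1e-12);
        return 0;
    }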
input, T beta) {
-        CV_Assert(is_shape_same(input, output));
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_bias1_vec_kernel<T, 4>(stream, output, input, beta);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_bias1_vec_kernel<T, 2>(stream, output, input, beta);
-        } else {
-            launch_bias1_vec_kernel<T, 1>(stream, output, input, beta);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void bias1<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, __half);
-#endif
-    template void bias1<float>(const Stream&, TensorSpan<float>, TensorView<float>, float);
-
     template <class T, std::size_t N> static
     void launch_biasN_vec_kernel(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> bias){
         CV_Assert(is_fully_aligned<T>(output, N));
@@ -195,34 +135,6 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
 #endif
     template void biasN<float>(const Stream&, TensorSpan<float>, TensorView<float>, std::size_t, TensorView<float>);

-    template <class T, std::size_t N> static
-    void launch_scale1_vec_kernel(const Stream& stream, Span<T> output, View<T> input, T alpha) {
-        CV_Assert(is_fully_aligned<T>(output, N));
-        CV_Assert(is_fully_aligned<T>(input, N));
-
-        auto kernel = raw::scale1_vec<T, N>;
-        auto policy = make_policy(kernel, output.size() / N, 0, stream);
-        launch_kernel(kernel, policy, output, input, alpha);
-    }
-
-    template <class T>
-    void scale1(const Stream& stream, TensorSpan<T> output, TensorView<T> input, T alpha) {
-        CV_Assert(is_shape_same(input, output));
-
-        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-            launch_scale1_vec_kernel<T, 4>(stream, output, input, alpha);
-        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-            launch_scale1_vec_kernel<T, 2>(stream, output, input, alpha);
-        } else {
-            launch_scale1_vec_kernel<T, 1>(stream, output, input, alpha);
-        }
-    }
-
-#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-    template void scale1<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, __half);
-#endif
-    template void scale1<float>(const Stream&, TensorSpan<float>, TensorView<float>, float);
-
     template <class T, std::size_t N> static
     void launch_scaleN_vec_kernel(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> weights) {
         CV_Assert(is_fully_aligned<T>(output, N));
diff --git a/modules/dnn/src/cuda4dnn/kernels/bias_activation.hpp b/modules/dnn/src/cuda4dnn/kernels/bias_activation.hpp
index 93660a8c33..500f9bb567 100644
--- a/modules/dnn/src/cuda4dnn/kernels/bias_activation.hpp
+++ b/modules/dnn/src/cuda4dnn/kernels/bias_activation.hpp
@@ -19,7 +19,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
     void biasN_clipped_relu_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, T floor, T ceiling);

     template <class T>
-    void biasN_power_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, T exp);
+    void biasN_power_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias, T exp, T scale, T shift);

     template <class T>
     void biasN_tanh_inplace(const csl::Stream& stream, csl::Span<T> inplace_output, std::size_t inner_size, csl::View<T> bias);
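The widened signature matches `power_functor`, which computes `pow(shift + scale * value, exp)`; with `exp == 1` the same functor degenerates to a plain fused `shift + scale * x`, which is presumably what the `scale1_with_bias1` helper called from normalize_bbox.hpp below builds on (its definition is outside this excerpt). An illustrative call with hypothetical parameter values; the convolution fusion path below passes `T(1.0), T(0.0)` precisely so the old `pow`-only behaviour is preserved:

    // In-place fused bias + Power activation; presumably y = pow(shift + scale * (x + bias), exp).
    // The exp/scale/shift values here are illustrative, not taken from the patch.
    kernels::biasN_power_inplace<float>(stream, output, inner_size, biasTensor,
                                        /*exp=*/2.0f, /*scale=*/0.5f, /*shift=*/1.0f);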
diff --git a/modules/dnn/src/cuda4dnn/kernels/scale_shift.hpp b/modules/dnn/src/cuda4dnn/kernels/scale_shift.hpp
index 32fa1d8b72..7b7da3bc92 100644
--- a/modules/dnn/src/cuda4dnn/kernels/scale_shift.hpp
+++ b/modules/dnn/src/cuda4dnn/kernels/scale_shift.hpp
@@ -12,18 +12,12 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {

-    template <class T>
-    void bias1(const csl::Stream& stream, csl::TensorSpan<T> output, csl::TensorView<T> input, T alpha);
-
     template <class T>
     void biasN(const csl::Stream& stream, csl::TensorSpan<T> output, csl::TensorView<T> input, std::size_t inner_size, csl::TensorView<T> bias);

-    template <class T>
-    void scale1(const csl::Stream& stream, csl::TensorSpan<T> output, csl::TensorView<T> input, T alpha);
-
     template <class T>
     void scaleN(const csl::Stream& stream, csl::TensorSpan<T> output,
diff --git a/modules/dnn/src/cuda4dnn/primitives/convolution.hpp b/modules/dnn/src/cuda4dnn/primitives/convolution.hpp
index 0a0050bd85..b0039525ae 100644
--- a/modules/dnn/src/cuda4dnn/primitives/convolution.hpp
+++ b/modules/dnn/src/cuda4dnn/primitives/convolution.hpp
@@ -286,7 +286,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
                     kernels::biasN_clipped_relu_inplace<T>(stream, output, inner_size, biasTensor, crelu_floor, crelu_ceil);
                     break;
                 case ConvolutionConfiguration::ActivationType::POWER:
-                    kernels::biasN_power_inplace<T>(stream, output, inner_size, biasTensor, power_exp);
+                    kernels::biasN_power_inplace<T>(stream, output, inner_size, biasTensor, power_exp, T(1.0), T(0.0));
                     break;
                 case ConvolutionConfiguration::ActivationType::TANH:
                     kernels::biasN_tanh_inplace<T>(stream, output, inner_size, biasTensor);
diff --git a/modules/dnn/src/cuda4dnn/primitives/normalize_bbox.hpp b/modules/dnn/src/cuda4dnn/primitives/normalize_bbox.hpp
index ecef608647..f067dddaa7 100644
--- a/modules/dnn/src/cuda4dnn/primitives/normalize_bbox.hpp
+++ b/modules/dnn/src/cuda4dnn/primitives/normalize_bbox.hpp
@@ -113,7 +113,7 @@ namespace cv { namespace dnn { namespace cuda4dnn {
              */
             if (weight != 1.0)
             {
-                kernels::scale1<T>(stream, output, input, weight);
+                kernels::scale1_with_bias1<T>(stream, output, input, weight, 1.0);
             }
             else if (!weightsTensor.empty())
             {