// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "vector_traits.hpp"

#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"

using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;

namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {

    namespace raw {
        /* Kernel: writes `value` into every element of `output`, one vector at a time.
         *
         * `vector_type` is the vector type selected by get_vector_type_t<T, N>; each loop
         * iteration fills one vector with copies of `value` and stores it with v_store.
         * The loop runs over output.size() / vector_type::size() vectors, so any trailing
         * elements that do not form a whole vector are NOT written. The launcher asserts
         * is_fully_aligned(output, N), which is presumably what guarantees there is no
         * such tail (TODO confirm against span.hpp).
         */
        template <class T, std::size_t N>
        __global__ void fill_vec(Span<T> output, T value) {
            using vector_type = get_vector_type_t<T, N>;

            auto output_vPtr = vector_type::get_pointer(output.data());
            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
                vector_type vec;
                for (int j = 0; j < vector_type::size(); j++)
                    vec.data[j] = value;
                v_store(output_vPtr[i], vec);
            }
        }

        /* Kernel: copies elements from `input` to `output`, one vector at a time.
         *
         * The iteration count is derived from output.size() only, so `input` must hold at
         * least as many elements as `output` (checked by launch_vectorized_copy). As with
         * fill_vec, a tail smaller than one vector would be skipped.
         */
        template <class T, std::size_t N>
        __global__ void copy_vec(Span<T> output, View<T> input) {
            using vector_type = get_vector_type_t<T, N>;

            auto input_vPtr = vector_type::get_pointer(input.data());
            auto output_vPtr = vector_type::get_pointer(output.data());
            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
                vector_type vec;
                v_load(vec, input_vPtr[i]);
                v_store(output_vPtr[i], vec);
            }
        }
    }

    /* Launches raw::fill_vec<T, N> on `stream`.
     *
     * Throws (via CV_Assert) if `output` is not fully aligned for vectors of N elements.
     * make_policy is asked for output.size() / N units of work, i.e. one per vector.
     */
    template <class T, std::size_t N> static
    void launch_vectorized_fill(const Stream& stream, Span<T> output, T value) {
        CV_Assert(is_fully_aligned<T>(output, N));

        auto kernel = raw::fill_vec<T, N>;
        auto policy = make_policy(kernel, output.size() / N, 0, stream);
        launch_kernel(kernel, policy, output, value);
    }

    /* Sets every element of `output` to `value` asynchronously on `stream`.
     *
     * Picks the widest vector width (4, then 2, then 1 element) for which `output` is
     * fully aligned and launches the matching kernel.
     */
    template <class T>
    void fill(const Stream& stream, Span<T> output, T value) {
        // Nothing to write: skip the alignment dispatch and avoid a launch with zero work.
        if (output.size() == 0)
            return;

        if (is_fully_aligned<T>(output, 4)) {
            launch_vectorized_fill<T, 4>(stream, output, value);
        } else if (is_fully_aligned<T>(output, 2)) {
            launch_vectorized_fill<T, 2>(stream, output, value);
        } else {
            launch_vectorized_fill<T, 1>(stream, output, value);
        }
    }

    // The __half instantiation is only emitted for the host pass or for device
    // architectures >= 5.3, as selected by the __CUDA_ARCH__ guard below.
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
    template void fill(const Stream&, Span<__half>, __half);
#endif
    template void fill(const Stream&, Span<float>, float);

    /* Launches raw::copy_vec<T, N> on `stream`.
     *
     * Throws (via CV_Assert) if either span is not fully aligned for vectors of N
     * elements, or if `input` is shorter than `output`.
     */
    template <class T, std::size_t N> static
    void launch_vectorized_copy(const Stream& stream, Span<T> output, View<T> input) {
        CV_Assert(is_fully_aligned<T>(output, N));
        CV_Assert(is_fully_aligned<T>(input, N));
        // copy_vec reads as many elements from `input` as it writes to `output`;
        // a shorter input would be read out of bounds on the device.
        CV_Assert(input.size() >= output.size());

        auto kernel = raw::copy_vec<T, N>;
        auto policy = make_policy(kernel, output.size() / N, 0, stream);
        launch_kernel(kernel, policy, output, input);
    }

    /* Copies output.size() elements from `input` into `output` asynchronously on `stream`.
     *
     * Uses the widest vector width (4, then 2, then 1 element) for which BOTH spans are
     * fully aligned.
     */
    template <class T>
    void copy(const Stream& stream, Span<T> output, View<T> input) {
        // Nothing to copy: skip the alignment dispatch and avoid a launch with zero work.
        if (output.size() == 0)
            return;

        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
            launch_vectorized_copy<T, 4>(stream, output, input);
        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
            launch_vectorized_copy<T, 2>(stream, output, input);
        } else {
            launch_vectorized_copy<T, 1>(stream, output, input);
        }
    }

    // The __half instantiation is only emitted for the host pass or for device
    // architectures >= 5.3, as selected by the __CUDA_ARCH__ guard below.
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
    template void copy(const Stream&, Span<__half>, View<__half>);
#endif
    template void copy(const Stream&, Span<float>, View<float>);

}}}} /* namespace cv::dnn::cuda4dnn::kernels */