Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low
CUDA backend for the DNN module
* stub cuda4dnn design
* minor fixes for tests and doxygen
* add csl public api directory to module headers
* add low-level CSL components
* add high-level CSL components
* integrate csl::Tensor into backbone code
* switch to CPU iff unsupported; otherwise, fail on error
* add fully connected layer
* add softmax layer
* add activation layers
* support arbitary rank TensorDescriptor
* pass input wrappers to `initCUDA()`
* add 1d/2d/3d-convolution
* add pooling layer
* reorganize and refactor code
* fixes for gcc, clang and doxygen; remove cxx14/17 code
* add blank_layer
* add LRN layer
* add rounding modes for pooling layer
* split tensor.hpp into tensor.hpp and tensor_ops.hpp
* add concat layer
* add scale layer
* add batch normalization layer
* split math.cu into activations.cu and math.hpp
* add eltwise layer
* add flatten layer
* add tensor transform api
* add asymmetric padding support for convolution layer
* add reshape layer
* fix rebase issues
* add permute layer
* add padding support for concat layer
* refactor and reorganize code
* add normalize layer
* optimize bias addition in scale layer
* add prior box layer
* fix and optimize normalize layer
* add asymmetric padding support for pooling layer
* add event API
* improve pooling performance for some padding scenarios
* avoid over-allocation of compute resources to kernels
* improve prior box performance
* enable layer fusion
* add const layer
* add resize layer
* add slice layer
* add padding layer
* add deconvolution layer
* fix channelwise ReLU initialization
* add vector traits
* add vectorized versions of relu, clipped_relu, power
* add vectorized concat kernels
* improve concat_with_offsets performance
* vectorize scale and bias kernels
* add support for multi-billion element tensors
* vectorize prior box kernels
* fix address alignment check
* improve bias addition performance of conv/deconv/fc layers
* restructure code for supporting multiple targets
* add DNN_TARGET_CUDA_FP64
* add DNN_TARGET_FP16
* improve vectorization
* add region layer
* improve tensor API, add dynamic ranks
1. use ManagedPtr instead of a Tensor in backend wrapper
2. add new methods to tensor classes
- size_range: computes the combined size of for a given axis range
- tensor span/view can be constructed from a raw pointer and shape
3. the tensor classes can change their rank at runtime (previously rank was fixed at compile-time)
4. remove device code from tensor classes (as they are unused)
5. enforce strict conditions on tensor class APIs to improve debugging ability
* fix parametric relu activation
* add squeeze/unsqueeze tensor API
* add reorg layer
* optimize permute and enable 2d permute
* enable 1d and 2d slice
* add split layer
* add shuffle channel layer
* allow tensors of different ranks in reshape primitive
* patch SliceOp to allow Crop Layer
* allow extra shape inputs in reshape layer
* use `std::move_backward` instead of `std::move` for insert in resizable_static_array
* improve workspace management
* add spatial LRN
* add nms (cpu) to region layer
* add max pooling with argmax ( and a fix to limits.hpp)
* add max unpooling layer
* rename DNN_TARGET_CUDA_FP32 to DNN_TARGET_CUDA
* update supportBackend to be more rigorous
* remove stray include from preventing non-cuda build
* include op_cuda.hpp outside condition #if
* refactoring, fixes and many optimizations
* drop DNN_TARGET_CUDA_FP64
* fix gcc errors
* increase max. tensor rank limit to six
* add Interp layer
* drop custom layers; use BackendNode
* vectorize activation kernels
* fixes for gcc
* remove wrong assertion
* fix broken assertion in unpooling primitive
* fix build errors in non-CUDA build
* completely remove workspace from public API
* fix permute layer
* enable accuracy and perf. tests for DNN_TARGET_CUDA
* add asynchronous forward
* vectorize eltwise ops
* vectorize fill kernel
* fixes for gcc
* remove CSL headers from public API
* remove csl header source group from cmake
* update min. cudnn version in cmake
* add numerically stable FP32 log1pexp
* refactor code
* add FP16 specialization to cudnn based tensor addition
* vectorize scale1 and bias1 + minor refactoring
* fix doxygen build
* fix invalid alignment assertion
* clear backend wrappers before allocateLayers
* ignore memory lock failures
* do not allocate internal blobs
* integrate NVTX
* add numerically stable half precision log1pexp
* fix indentation, following coding style, improve docs
* remove accidental modification of IE code
* Revert "add asynchronous forward"
This reverts commit 1154b9da9da07e9b52f8a81bdcea48cf31c56f70.
* [cmake] throw error for unsupported CC versions
* fix rebase issues
* add more docs, refactor code, fix bugs
* minor refactoring and fixes
* resolve warnings/errors from clang
* remove haveCUDA() checks from supportBackend()
* remove NVTX integration
* changes based on review comments
* avoid exception when no CUDA device is present
* add color code for CUDA in Net::dump
2019-10-21 19:28:00 +08:00
|
|
|
// This file is part of OpenCV project.
|
|
|
|
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
|
|
|
// of this distribution and at http://opencv.org/license.html.
|
|
|
|
|
|
|
|
#ifndef OPENCV_DNN_SRC_CUDA_VECTOR_TRAITS_HPP
|
|
|
|
#define OPENCV_DNN_SRC_CUDA_VECTOR_TRAITS_HPP
|
|
|
|
|
|
|
|
#include <cuda_runtime.h>
|
|
|
|
|
|
|
|
#include "types.hpp"
|
2020-05-10 01:20:30 +08:00
|
|
|
#include "memory.hpp"
|
Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low
CUDA backend for the DNN module
* stub cuda4dnn design
* minor fixes for tests and doxygen
* add csl public api directory to module headers
* add low-level CSL components
* add high-level CSL components
* integrate csl::Tensor into backbone code
* switch to CPU iff unsupported; otherwise, fail on error
* add fully connected layer
* add softmax layer
* add activation layers
* support arbitary rank TensorDescriptor
* pass input wrappers to `initCUDA()`
* add 1d/2d/3d-convolution
* add pooling layer
* reorganize and refactor code
* fixes for gcc, clang and doxygen; remove cxx14/17 code
* add blank_layer
* add LRN layer
* add rounding modes for pooling layer
* split tensor.hpp into tensor.hpp and tensor_ops.hpp
* add concat layer
* add scale layer
* add batch normalization layer
* split math.cu into activations.cu and math.hpp
* add eltwise layer
* add flatten layer
* add tensor transform api
* add asymmetric padding support for convolution layer
* add reshape layer
* fix rebase issues
* add permute layer
* add padding support for concat layer
* refactor and reorganize code
* add normalize layer
* optimize bias addition in scale layer
* add prior box layer
* fix and optimize normalize layer
* add asymmetric padding support for pooling layer
* add event API
* improve pooling performance for some padding scenarios
* avoid over-allocation of compute resources to kernels
* improve prior box performance
* enable layer fusion
* add const layer
* add resize layer
* add slice layer
* add padding layer
* add deconvolution layer
* fix channelwise ReLU initialization
* add vector traits
* add vectorized versions of relu, clipped_relu, power
* add vectorized concat kernels
* improve concat_with_offsets performance
* vectorize scale and bias kernels
* add support for multi-billion element tensors
* vectorize prior box kernels
* fix address alignment check
* improve bias addition performance of conv/deconv/fc layers
* restructure code for supporting multiple targets
* add DNN_TARGET_CUDA_FP64
* add DNN_TARGET_FP16
* improve vectorization
* add region layer
* improve tensor API, add dynamic ranks
1. use ManagedPtr instead of a Tensor in backend wrapper
2. add new methods to tensor classes
- size_range: computes the combined size of for a given axis range
- tensor span/view can be constructed from a raw pointer and shape
3. the tensor classes can change their rank at runtime (previously rank was fixed at compile-time)
4. remove device code from tensor classes (as they are unused)
5. enforce strict conditions on tensor class APIs to improve debugging ability
* fix parametric relu activation
* add squeeze/unsqueeze tensor API
* add reorg layer
* optimize permute and enable 2d permute
* enable 1d and 2d slice
* add split layer
* add shuffle channel layer
* allow tensors of different ranks in reshape primitive
* patch SliceOp to allow Crop Layer
* allow extra shape inputs in reshape layer
* use `std::move_backward` instead of `std::move` for insert in resizable_static_array
* improve workspace management
* add spatial LRN
* add nms (cpu) to region layer
* add max pooling with argmax ( and a fix to limits.hpp)
* add max unpooling layer
* rename DNN_TARGET_CUDA_FP32 to DNN_TARGET_CUDA
* update supportBackend to be more rigorous
* remove stray include from preventing non-cuda build
* include op_cuda.hpp outside condition #if
* refactoring, fixes and many optimizations
* drop DNN_TARGET_CUDA_FP64
* fix gcc errors
* increase max. tensor rank limit to six
* add Interp layer
* drop custom layers; use BackendNode
* vectorize activation kernels
* fixes for gcc
* remove wrong assertion
* fix broken assertion in unpooling primitive
* fix build errors in non-CUDA build
* completely remove workspace from public API
* fix permute layer
* enable accuracy and perf. tests for DNN_TARGET_CUDA
* add asynchronous forward
* vectorize eltwise ops
* vectorize fill kernel
* fixes for gcc
* remove CSL headers from public API
* remove csl header source group from cmake
* update min. cudnn version in cmake
* add numerically stable FP32 log1pexp
* refactor code
* add FP16 specialization to cudnn based tensor addition
* vectorize scale1 and bias1 + minor refactoring
* fix doxygen build
* fix invalid alignment assertion
* clear backend wrappers before allocateLayers
* ignore memory lock failures
* do not allocate internal blobs
* integrate NVTX
* add numerically stable half precision log1pexp
* fix indentation, following coding style, improve docs
* remove accidental modification of IE code
* Revert "add asynchronous forward"
This reverts commit 1154b9da9da07e9b52f8a81bdcea48cf31c56f70.
* [cmake] throw error for unsupported CC versions
* fix rebase issues
* add more docs, refactor code, fix bugs
* minor refactoring and fixes
* resolve warnings/errors from clang
* remove haveCUDA() checks from supportBackend()
* remove NVTX integration
* changes based on review comments
* avoid exception when no CUDA device is present
* add color code for CUDA in Net::dump
2019-10-21 19:28:00 +08:00
|
|
|
|
|
|
|
#include "../cuda4dnn/csl/pointer.hpp"
|
|
|
|
|
|
|
|
#include <type_traits>
|
|
|
|
|
|
|
|
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device {
|
|
|
|
|
|
|
|
/** \file vector_traits.hpp
|
|
|
|
* \brief utility classes and functions for vectorized memory loads/stores
|
|
|
|
*
|
|
|
|
* Example:
|
|
|
|
* using vector_type = get_vector_type_t<float, 4>;
|
|
|
|
*
|
|
|
|
* auto input_vPtr = type::get_pointer(iptr); // iptr is of type DevicePtr<const float>
|
|
|
|
* auto output_vPtr = type::get_pointer(optr); // optr is of type DevicePtr<float>
|
|
|
|
*
|
|
|
|
* vector_type vec;
|
|
|
|
* v_load(vec, input_vPtr);
|
|
|
|
*
|
|
|
|
* for(int i = 0; i < vector_type::size(); i++)
|
|
|
|
* vec[i] = do_something(vec[i]);
|
|
|
|
*
|
|
|
|
* v_store(output_vPtr, vec);
|
|
|
|
*/
|
|
|
|
|
|
|
|
namespace detail {
|
|
|
|
template <size_type N> struct raw_type_ { };
|
|
|
|
template <> struct raw_type_<256> { typedef ulonglong4 type; };
|
|
|
|
template <> struct raw_type_<128> { typedef uint4 type; };
|
|
|
|
template <> struct raw_type_<64> { typedef uint2 type; };
|
|
|
|
template <> struct raw_type_<32> { typedef uint1 type; };
|
|
|
|
template <> struct raw_type_<16> { typedef uchar2 type; };
|
|
|
|
template <> struct raw_type_<8> { typedef uchar1 type; };
|
|
|
|
|
|
|
|
template <size_type N> struct raw_type {
|
|
|
|
using type = typename raw_type_<N>::type;
|
|
|
|
static_assert(sizeof(type) * 8 == N, "");
|
|
|
|
};
|
|
|
|
}
|
|
|
|
|
|
|
|
/* \tparam T type of element in the vector
|
|
|
|
* \tparam N "number of elements" of type T in the vector
|
|
|
|
*/
|
|
|
|
template <class T, size_type N>
|
|
|
|
union vector_type {
|
|
|
|
using value_type = T;
|
|
|
|
using raw_type = typename detail::raw_type<N * sizeof(T) * 8>::type;
|
|
|
|
|
|
|
|
__device__ vector_type() { }
|
|
|
|
|
|
|
|
__device__ static constexpr size_type size() { return N; }
|
|
|
|
|
|
|
|
raw_type raw;
|
|
|
|
T data[N];
|
|
|
|
|
|
|
|
template <class U> static __device__
|
|
|
|
typename std::enable_if<std::is_const<U>::value, const vector_type*>
|
|
|
|
::type get_pointer(csl::DevicePtr<U> ptr) {
|
|
|
|
return reinterpret_cast<const vector_type*>(ptr.get());
|
|
|
|
}
|
|
|
|
|
|
|
|
template <class U> static __device__
|
|
|
|
typename std::enable_if<!std::is_const<U>::value, vector_type*>
|
|
|
|
::type get_pointer(csl::DevicePtr<U> ptr) {
|
|
|
|
return reinterpret_cast<vector_type*>(ptr.get());
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
template <class V>
|
|
|
|
__device__ void v_load(V& dest, const V& src) {
|
|
|
|
dest.raw = src.raw;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <class V>
|
|
|
|
__device__ void v_load(V& dest, const V* src) {
|
|
|
|
dest.raw = src->raw;
|
|
|
|
}
|
|
|
|
|
2020-05-10 01:20:30 +08:00
|
|
|
template <class V>
|
|
|
|
__device__ void v_load_ldg(V& dest, const V& src) {
|
|
|
|
dest.raw = load_ldg(src.raw);
|
|
|
|
}
|
|
|
|
|
|
|
|
template <class V>
|
|
|
|
__device__ void v_load_ldg(V& dest, const V* src) {
|
|
|
|
dest.raw = load_ldg(src->raw);
|
|
|
|
}
|
|
|
|
|
Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low
CUDA backend for the DNN module
* stub cuda4dnn design
* minor fixes for tests and doxygen
* add csl public api directory to module headers
* add low-level CSL components
* add high-level CSL components
* integrate csl::Tensor into backbone code
* switch to CPU iff unsupported; otherwise, fail on error
* add fully connected layer
* add softmax layer
* add activation layers
* support arbitary rank TensorDescriptor
* pass input wrappers to `initCUDA()`
* add 1d/2d/3d-convolution
* add pooling layer
* reorganize and refactor code
* fixes for gcc, clang and doxygen; remove cxx14/17 code
* add blank_layer
* add LRN layer
* add rounding modes for pooling layer
* split tensor.hpp into tensor.hpp and tensor_ops.hpp
* add concat layer
* add scale layer
* add batch normalization layer
* split math.cu into activations.cu and math.hpp
* add eltwise layer
* add flatten layer
* add tensor transform api
* add asymmetric padding support for convolution layer
* add reshape layer
* fix rebase issues
* add permute layer
* add padding support for concat layer
* refactor and reorganize code
* add normalize layer
* optimize bias addition in scale layer
* add prior box layer
* fix and optimize normalize layer
* add asymmetric padding support for pooling layer
* add event API
* improve pooling performance for some padding scenarios
* avoid over-allocation of compute resources to kernels
* improve prior box performance
* enable layer fusion
* add const layer
* add resize layer
* add slice layer
* add padding layer
* add deconvolution layer
* fix channelwise ReLU initialization
* add vector traits
* add vectorized versions of relu, clipped_relu, power
* add vectorized concat kernels
* improve concat_with_offsets performance
* vectorize scale and bias kernels
* add support for multi-billion element tensors
* vectorize prior box kernels
* fix address alignment check
* improve bias addition performance of conv/deconv/fc layers
* restructure code for supporting multiple targets
* add DNN_TARGET_CUDA_FP64
* add DNN_TARGET_FP16
* improve vectorization
* add region layer
* improve tensor API, add dynamic ranks
1. use ManagedPtr instead of a Tensor in backend wrapper
2. add new methods to tensor classes
- size_range: computes the combined size of for a given axis range
- tensor span/view can be constructed from a raw pointer and shape
3. the tensor classes can change their rank at runtime (previously rank was fixed at compile-time)
4. remove device code from tensor classes (as they are unused)
5. enforce strict conditions on tensor class APIs to improve debugging ability
* fix parametric relu activation
* add squeeze/unsqueeze tensor API
* add reorg layer
* optimize permute and enable 2d permute
* enable 1d and 2d slice
* add split layer
* add shuffle channel layer
* allow tensors of different ranks in reshape primitive
* patch SliceOp to allow Crop Layer
* allow extra shape inputs in reshape layer
* use `std::move_backward` instead of `std::move` for insert in resizable_static_array
* improve workspace management
* add spatial LRN
* add nms (cpu) to region layer
* add max pooling with argmax ( and a fix to limits.hpp)
* add max unpooling layer
* rename DNN_TARGET_CUDA_FP32 to DNN_TARGET_CUDA
* update supportBackend to be more rigorous
* remove stray include from preventing non-cuda build
* include op_cuda.hpp outside condition #if
* refactoring, fixes and many optimizations
* drop DNN_TARGET_CUDA_FP64
* fix gcc errors
* increase max. tensor rank limit to six
* add Interp layer
* drop custom layers; use BackendNode
* vectorize activation kernels
* fixes for gcc
* remove wrong assertion
* fix broken assertion in unpooling primitive
* fix build errors in non-CUDA build
* completely remove workspace from public API
* fix permute layer
* enable accuracy and perf. tests for DNN_TARGET_CUDA
* add asynchronous forward
* vectorize eltwise ops
* vectorize fill kernel
* fixes for gcc
* remove CSL headers from public API
* remove csl header source group from cmake
* update min. cudnn version in cmake
* add numerically stable FP32 log1pexp
* refactor code
* add FP16 specialization to cudnn based tensor addition
* vectorize scale1 and bias1 + minor refactoring
* fix doxygen build
* fix invalid alignment assertion
* clear backend wrappers before allocateLayers
* ignore memory lock failures
* do not allocate internal blobs
* integrate NVTX
* add numerically stable half precision log1pexp
* fix indentation, following coding style, improve docs
* remove accidental modification of IE code
* Revert "add asynchronous forward"
This reverts commit 1154b9da9da07e9b52f8a81bdcea48cf31c56f70.
* [cmake] throw error for unsupported CC versions
* fix rebase issues
* add more docs, refactor code, fix bugs
* minor refactoring and fixes
* resolve warnings/errors from clang
* remove haveCUDA() checks from supportBackend()
* remove NVTX integration
* changes based on review comments
* avoid exception when no CUDA device is present
* add color code for CUDA in Net::dump
2019-10-21 19:28:00 +08:00
|
|
|
template <class V>
|
|
|
|
__device__ void v_store(V* dest, const V& src) {
|
|
|
|
dest->raw = src.raw;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <class V>
|
|
|
|
__device__ void v_store(V& dest, const V& src) {
|
|
|
|
dest.raw = src.raw;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <class T, size_type N>
|
|
|
|
struct get_vector_type {
|
|
|
|
typedef vector_type<T, N> type;
|
|
|
|
};
|
|
|
|
|
|
|
|
template <class T, size_type N>
|
|
|
|
using get_vector_type_t = typename get_vector_type<T, N>::type;
|
|
|
|
|
|
|
|
}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */
|
|
|
|
|
|
|
|
#endif /* OPENCV_DNN_SRC_CUDA_VECTOR_TRAITS_HPP */
|