Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low
CUDA backend for the DNN module
* stub cuda4dnn design
* minor fixes for tests and doxygen
* add csl public api directory to module headers
* add low-level CSL components
* add high-level CSL components
* integrate csl::Tensor into backbone code
* switch to CPU iff unsupported; otherwise, fail on error
* add fully connected layer
* add softmax layer
* add activation layers
* support arbitrary rank TensorDescriptor
* pass input wrappers to `initCUDA()`
* add 1d/2d/3d-convolution
* add pooling layer
* reorganize and refactor code
* fixes for gcc, clang and doxygen; remove cxx14/17 code
* add blank_layer
* add LRN layer
* add rounding modes for pooling layer
* split tensor.hpp into tensor.hpp and tensor_ops.hpp
* add concat layer
* add scale layer
* add batch normalization layer
* split math.cu into activations.cu and math.hpp
* add eltwise layer
* add flatten layer
* add tensor transform api
* add asymmetric padding support for convolution layer
* add reshape layer
* fix rebase issues
* add permute layer
* add padding support for concat layer
* refactor and reorganize code
* add normalize layer
* optimize bias addition in scale layer
* add prior box layer
* fix and optimize normalize layer
* add asymmetric padding support for pooling layer
* add event API
* improve pooling performance for some padding scenarios
* avoid over-allocation of compute resources to kernels
* improve prior box performance
* enable layer fusion
* add const layer
* add resize layer
* add slice layer
* add padding layer
* add deconvolution layer
* fix channelwise ReLU initialization
* add vector traits
* add vectorized versions of relu, clipped_relu, power
* add vectorized concat kernels
* improve concat_with_offsets performance
* vectorize scale and bias kernels
* add support for multi-billion element tensors
* vectorize prior box kernels
* fix address alignment check
* improve bias addition performance of conv/deconv/fc layers
* restructure code for supporting multiple targets
* add DNN_TARGET_CUDA_FP64
* add DNN_TARGET_FP16
* improve vectorization
* add region layer
* improve tensor API, add dynamic ranks
1. use ManagedPtr instead of a Tensor in backend wrapper
2. add new methods to tensor classes
    - size_range: computes the combined size for a given axis range
- tensor span/view can be constructed from a raw pointer and shape
3. the tensor classes can change their rank at runtime (previously rank was fixed at compile-time)
4. remove device code from tensor classes (as they are unused)
5. enforce strict conditions on tensor class APIs to improve debugging ability
* fix parametric relu activation
* add squeeze/unsqueeze tensor API
* add reorg layer
* optimize permute and enable 2d permute
* enable 1d and 2d slice
* add split layer
* add shuffle channel layer
* allow tensors of different ranks in reshape primitive
* patch SliceOp to allow Crop Layer
* allow extra shape inputs in reshape layer
* use `std::move_backward` instead of `std::move` for insert in resizable_static_array
* improve workspace management
* add spatial LRN
* add nms (cpu) to region layer
* add max pooling with argmax (and a fix to limits.hpp)
* add max unpooling layer
* rename DNN_TARGET_CUDA_FP32 to DNN_TARGET_CUDA
* update supportBackend to be more rigorous
* remove stray include from preventing non-cuda build
* include op_cuda.hpp outside condition #if
* refactoring, fixes and many optimizations
* drop DNN_TARGET_CUDA_FP64
* fix gcc errors
* increase max. tensor rank limit to six
* add Interp layer
* drop custom layers; use BackendNode
* vectorize activation kernels
* fixes for gcc
* remove wrong assertion
* fix broken assertion in unpooling primitive
* fix build errors in non-CUDA build
* completely remove workspace from public API
* fix permute layer
* enable accuracy and perf. tests for DNN_TARGET_CUDA
* add asynchronous forward
* vectorize eltwise ops
* vectorize fill kernel
* fixes for gcc
* remove CSL headers from public API
* remove csl header source group from cmake
* update min. cudnn version in cmake
* add numerically stable FP32 log1pexp
* refactor code
* add FP16 specialization to cudnn based tensor addition
* vectorize scale1 and bias1 + minor refactoring
* fix doxygen build
* fix invalid alignment assertion
* clear backend wrappers before allocateLayers
* ignore memory lock failures
* do not allocate internal blobs
* integrate NVTX
* add numerically stable half precision log1pexp
* fix indentation, following coding style, improve docs
* remove accidental modification of IE code
* Revert "add asynchronous forward"
This reverts commit 1154b9da9da07e9b52f8a81bdcea48cf31c56f70.
* [cmake] throw error for unsupported CC versions
* fix rebase issues
* add more docs, refactor code, fix bugs
* minor refactoring and fixes
* resolve warnings/errors from clang
* remove haveCUDA() checks from supportBackend()
* remove NVTX integration
* changes based on review comments
* avoid exception when no CUDA device is present
* add color code for CUDA in Net::dump
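For reference, a minimal sketch of selecting the backend and targets added by this PR from user code (the model path, input image and input size below are placeholders):

    #include <opencv2/dnn.hpp>
    #include <opencv2/imgcodecs.hpp>

    int main() {
        cv::dnn::Net net = cv::dnn::readNet("model.onnx");    // placeholder model
        net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
        net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA);     // or DNN_TARGET_CUDA_FP16 on supported GPUs

        cv::Mat image = cv::imread("input.jpg");               // placeholder input
        cv::Mat blob = cv::dnn::blobFromImage(image, 1.0, cv::Size(224, 224));
        net.setInput(blob);
        cv::Mat out = net.forward();
        return 0;
    }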
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include "array.hpp"
#include "functors.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "vector_traits.hpp"
#include "kernel_dispatcher.hpp"

#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"
#include "../cuda4dnn/csl/tensor.hpp"

#include <opencv2/core.hpp>

#include <cstddef>
#include <vector>
#include <algorithm>
#include <numeric>
#include <functional>

using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;

namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {

    namespace raw {
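        /* eltwise_op_vec: fast path used when no broadcasting is required.
         *
         * Each thread handles N consecutive elements per iteration: it loads one vector from
         * `x` and one from `y`, applies the elementwise functor lane by lane and stores the
         * result. The launcher asserts that the tensors are suitably aligned for N-element
         * vector accesses.
         */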
        template <class T, class EltwiseOp, std::size_t N>
        __global__ void eltwise_op_vec(Span<T> output, View<T> x, View<T> y, const typename EltwiseOp::Params params) {
            using vector_type = get_vector_type_t<T, N>;

            auto output_vPtr = vector_type::get_pointer(output.data());
            auto x_vPtr = vector_type::get_pointer(x.data());
            auto y_vPtr = vector_type::get_pointer(y.data());

            EltwiseOp eltwise_op(params);

            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
                vector_type vec_x, vec_y;
                v_load(vec_x, x_vPtr[i]);
                v_load(vec_y, y_vPtr[i]);
                for (int j = 0; j < vector_type::size(); j++)
                    vec_x.data[j] = eltwise_op(vec_x.data[j], vec_y.data[j]);
                v_store(output_vPtr[i], vec_x);
            }
        }
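
        /* eltwise_op_bcast: generic elementwise kernel with broadcasting support.
         *
         * `out_strides` holds the row-major strides of the output; the loop below decomposes
         * the flat output index into per-axis coordinates (i / stride[0], then
         * (i % stride[j - 1]) / stride[j]) and accumulates the corresponding input offsets,
         * skipping axes that are broadcast (their contribution is zero).
         *
         * Worked example: output shape [2, 3, 4] gives out_strides [12, 4, 1]. For an input
         * `y` of shape [1, 3, 1] (broadcast on axes 0 and 2, y_strides [3, 1, 1]), only the
         * axis-1 coordinate contributes, so y_index = ((i % 12) / 4) * 1.
         */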
        template <class T, class EltwiseOp, std::size_t Rank>
        __global__ void eltwise_op_bcast(
            Span<T> output, array<size_type, Rank> out_strides,
            View<T> x, array<size_type, Rank> x_strides, array<bool, Rank> x_bcast,
            View<T> y, array<size_type, Rank> y_strides, array<bool, Rank> y_bcast,
            const typename EltwiseOp::Params params) {
            EltwiseOp eltwise_op(params);

            for (auto i : grid_stride_range(output.size())) {
                index_type out_index = i / out_strides[0];
                index_type x_index = x_bcast[0] ? 0 : out_index * x_strides[0];
                index_type y_index = y_bcast[0] ? 0 : out_index * y_strides[0];

                for (int j = 1; j < Rank; j++)
                {
                    out_index = (i % out_strides[j - 1]) / out_strides[j];
                    if (!x_bcast[j])
                        x_index += out_index * x_strides[j];
                    if (!y_bcast[j])
                        y_index += out_index * y_strides[j];
                }

                output[i] = eltwise_op(x[x_index], y[y_index]);
            }
        }
    }

    template <class T, class EltwiseOp, std::size_t N> static
    void launch_vectorized_eltwise_op(const Stream& stream, Span<T> output, View<T> x, View<T> y, const typename EltwiseOp::Params& params) {
        CV_Assert(x.size() == y.size());
        CV_Assert(x.size() == output.size());
        CV_Assert(is_fully_aligned<T>(output, N));
        CV_Assert(is_fully_aligned<T>(x, N));
        CV_Assert(is_fully_aligned<T>(y, N));

        auto kernel = raw::eltwise_op_vec<T, EltwiseOp, N>;
        auto policy = make_policy(kernel, output.size() / N, 0, stream);
        launch_kernel(kernel, policy, output, x, y, params);
    }

    template <class T, class EltwiseOp, std::size_t Rank> static
    void launch_eltwise_op_bcast(
        const Stream& stream,
        Span<T> output, const std::vector<std::size_t>& outStride,
        View<T> x, const std::vector<std::size_t>& inStride1, const std::vector<int>& inBcast1,
        View<T> y, const std::vector<std::size_t>& inStride2, const std::vector<int>& inBcast2,
        const typename EltwiseOp::Params& params)
    {
        CV_Assert(outStride.size() == Rank);
        CV_Assert(inStride1.size() == Rank);
        CV_Assert(inStride2.size() == Rank);
        CV_Assert(inBcast1.size() == Rank);
        CV_Assert(inBcast2.size() == Rank);

        array<size_type, Rank> outStride_k, inStride1_k, inStride2_k;
        outStride_k.assign(std::begin(outStride), std::end(outStride));
        inStride1_k.assign(std::begin(inStride1), std::end(inStride1));
        inStride2_k.assign(std::begin(inStride2), std::end(inStride2));

        array<bool, Rank> inBcast1_k, inBcast2_k;
        inBcast1_k.assign(std::begin(inBcast1), std::end(inBcast1));
        inBcast2_k.assign(std::begin(inBcast2), std::end(inBcast2));

        auto kernel = raw::eltwise_op_bcast<T, EltwiseOp, Rank>;
        auto policy = make_policy(kernel, output.size(), 0, stream);
        launch_kernel(kernel, policy, output, outStride_k, x, inStride1_k, inBcast1_k, y, inStride2_k, inBcast2_k, params);
    }
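
    /* The dispatcher macro below (see kernel_dispatcher.hpp) generates `eltwise_op_bcast_dispatcher`,
     * which maps a runtime rank in [1, CSL_MAX_TENSOR_RANK] to the launch_eltwise_op_bcast
     * instantiation with the matching compile-time Rank.
     */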
    GENERATE_KERNEL_DISPATCHER_2TP(eltwise_op_bcast_dispatcher, launch_eltwise_op_bcast);

    template <class T, class EltwiseOp> static
    void eltwise_op(const Stream& stream, TensorSpan<T> output, TensorView<T> x, TensorView<T> y, const typename EltwiseOp::Params& params = {}) {
        if (is_shape_same(output, x) && is_shape_same(output, y))
        {
            /* no broadcasting; use fast path */
            CV_Assert(x.size() == y.size());
            CV_Assert(x.size() == output.size());

            if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
                launch_vectorized_eltwise_op<T, EltwiseOp, 4>(stream, output, x, y, params);
            } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
                launch_vectorized_eltwise_op<T, EltwiseOp, 2>(stream, output, x, y, params);
            } else {
                launch_vectorized_eltwise_op<T, EltwiseOp, 1>(stream, output, x, y, params);
            }
        }
        else
        {
            CV_Assert(is_shape_compatible(output, x));
            CV_Assert(is_shape_compatible(output, y));

            /* matching singleton axes in both input tensors can be eliminated
             *
             * Reasoning:
             * ----------
             * Singleton axes do not contribute towards address calculation. They are redundant
             * unless there is broadcasting. If both input tensors have a singleton axis at a
             * given position, there is no broadcasting on that axis.
             *
             * Example:
             * ---------
             * x: [1, 256, 32, 32] -> [256, 32, 32]
             * y: [1, 256, 1, 1] -> [256, 1, 1]
             */
            for (int r = 0; r < output.rank(); r++)
            {
                while (x.get_axis_size(r) == 1 && y.get_axis_size(r) == 1) {
                    CV_Assert(output.get_axis_size(r) == 1);

                    x.squeeze(r);
                    y.squeeze(r);
                    output.squeeze(r);
                }
            }

            auto inShape1 = x.shape_as_vector();
            auto inShape2 = y.shape_as_vector();
            auto outShape = output.shape_as_vector();

            /* contiguous axes that do not broadcast can be merged into one axis
             *
             * Example:
             * ---------
             * x: [32, 8, 8] -> [32, 64]
             * y: [1, 8, 8] -> [1, 64]
             */
            for (int i = 0; i < inShape1.size(); i++) {
                /* check if axis `i` requires any broadcasting */
                if (inShape1[i] == inShape2[i]) {
                    /* loop invariant: `i` is the first axis in the contiguous axis sequence */

                    int j = i + 1; /* `j` is the axis which we will attempt to merge */
                    while (j < inShape1.size() && inShape1[j] == inShape2[j]) {
                        CV_Assert(outShape[j] == inShape1[j]);

                        /* `j` axis is also used fully; merge `i` and `j` */
                        auto new_size = inShape1[i] * inShape1[j];
                        inShape1[i] = new_size;
                        inShape2[i] = new_size;
                        outShape[i] = new_size;

                        /* delete axis `j` */
                        inShape1.erase(std::begin(inShape1) + j);
                        inShape2.erase(std::begin(inShape2) + j);
                        outShape.erase(std::begin(outShape) + j);

                        /* optimizations should not break the invariants */
                        CV_Assert(inShape1.size() == outShape.size());
                        CV_Assert(inShape2.size() == outShape.size());
                        CV_Assert(inShape1[i] == outShape[i]);
                        CV_Assert(inShape2[i] == outShape[i]);
                    }
                }
            }

            /* contiguous broadcasting axes on the same tensor can be merged into one axis
             *
             * Example:
             * ---------
             * x: [256, 8, 8] -> [256, 64]
             * y: [256, 1, 1] -> [256, 1]
             */
            for (int i = 0; i < inShape1.size(); i++) {
                /* check if axis `i` requires any broadcasting in tensor 1 */
                if (inShape1[i] == 1 && inShape2[i] != 1) {
                    /* loop invariant: `i` is the first axis in the contiguous axis sequence */

                    int j = i + 1; /* `j` is the axis which we will attempt to merge */
                    while (j < inShape1.size() && inShape1[j] == 1 && inShape2[j] != 1) {
                        CV_Assert(outShape[j] == inShape2[j]);

                        /* `j` axis is also used fully; merge `i` and `j` */
                        inShape1[i] = 1;
                        inShape2[i] = inShape2[i] * inShape2[j];
                        outShape[i] = inShape2[i];

                        /* delete axis `j` */
                        inShape1.erase(std::begin(inShape1) + j);
                        inShape2.erase(std::begin(inShape2) + j);
                        outShape.erase(std::begin(outShape) + j);

                        /* optimizations should not break the invariants */
                        CV_Assert(inShape1.size() == outShape.size());
                        CV_Assert(inShape2.size() == outShape.size());
                        CV_Assert(inShape1[i] == 1);
                        CV_Assert(inShape2[i] == outShape[i]);
                    }
                }

                /* check if axis `i` requires any broadcasting in tensor 2 */
                if (inShape1[i] != 1 && inShape2[i] == 1) {
                    /* loop invariant: `i` is the first axis in the contiguous axis sequence */

                    int j = i + 1; /* `j` is the axis which we will attempt to merge */
                    while (j < inShape1.size() && inShape1[j] != 1 && inShape2[j] == 1) {
                        CV_Assert(outShape[j] == inShape1[j]);

                        /* `j` axis is also used fully; merge `i` and `j` */
                        inShape1[i] = inShape1[i] * inShape1[j];
                        inShape2[i] = 1;
                        outShape[i] = inShape1[i];

                        /* delete axis `j` */
                        inShape1.erase(std::begin(inShape1) + j);
                        inShape2.erase(std::begin(inShape2) + j);
                        outShape.erase(std::begin(outShape) + j);

                        /* optimizations should not break the invariants */
                        CV_Assert(inShape1.size() == outShape.size());
                        CV_Assert(inShape2.size() == outShape.size());
                        CV_Assert(inShape1[i] == outShape[i]);
                        CV_Assert(inShape2[i] == 1);
                    }
                }
            }

            auto rank = outShape.size();

            std::vector<std::size_t> inStride1(rank), inStride2(rank), outStride(rank);
            inStride1.back() = 1;
            inStride2.back() = 1;
            outStride.back() = 1;
            /* garbage, ..., garbage, 1 */

            std::copy(std::begin(inShape1) + 1, std::end(inShape1), std::begin(inStride1));
            std::copy(std::begin(inShape2) + 1, std::end(inShape2), std::begin(inStride2));
            std::copy(std::begin(outShape) + 1, std::end(outShape), std::begin(outStride));
            /* dim[0], dim[1], ..., dim[-1], 1 */

            std::partial_sum(inStride1.rbegin(), inStride1.rend(), inStride1.rbegin(), std::multiplies<std::size_t>());
            std::partial_sum(inStride2.rbegin(), inStride2.rend(), inStride2.rbegin(), std::multiplies<std::size_t>());
            std::partial_sum(outStride.rbegin(), outStride.rend(), outStride.rbegin(), std::multiplies<std::size_t>());
            /* stride[0], stride[1], ..., stride[-2], 1 */
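            /* worked example: outShape [2, 3, 4] -> copy gives [3, 4, 1] -> suffix product
             * gives outStride [12, 4, 1], i.e. row-major strides */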

            std::vector<int> inBcast1(rank), inBcast2(rank);
            std::transform(std::begin(inShape1), std::end(inShape1), std::begin(inBcast1), [](std::size_t sz) { return sz == 1; });
            std::transform(std::begin(inShape2), std::end(inShape2), std::begin(inBcast2), [](std::size_t sz) { return sz == 1; });

            CV_Assert(1 <= rank && rank <= CSL_MAX_TENSOR_RANK);
            eltwise_op_bcast_dispatcher<T, EltwiseOp, 1, CSL_MAX_TENSOR_RANK>(rank, stream, output, outStride, x, inStride1, inBcast1, y, inStride2, inBcast2, params);
|
|
|
    }
}
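
/* The wrappers that follow bind concrete functors (MaxFunctor, SumFunctor, ...) to the generic
 * eltwise_op above. Those functors are defined elsewhere in the module and are not part of this
 * excerpt; the sketch below only illustrates the assumed pattern -- a device-callable binary
 * functor applied element by element by a templated kernel. All names in it are hypothetical,
 * not the actual cuda4dnn definitions.
 */
namespace functor_sketch {
    /* hypothetical binary functor: element-wise maximum */
    template <class T>
    struct max_functor {
        __device__ T operator()(T x, T y) const { return x > y ? x : y; }
    };

    /* hypothetical generic kernel: applies any such functor over flat arrays with a grid-stride loop */
    template <class T, class BinaryOp>
    __global__ void apply(T* out, const T* x, const T* y, std::size_t n, BinaryOp op) {
        for (std::size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x)
            out[i] = op(x[i], y[i]);
    }
    /* launch sketch: apply<<<blocks, threads, 0, stream>>>(out, x, y, n, max_functor<float>{}); */
}
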
template <class T>
void eltwise_max_2(const Stream& stream, TensorSpan<T> output, TensorView<T> x, TensorView<T> y) {
    eltwise_op<T, MaxFunctor<T>>(stream, output, x, y);
}

template <class T>
void eltwise_min_2(const Stream& stream, TensorSpan<T> output, TensorView<T> x, TensorView<T> y) {
    eltwise_op<T, MinFunctor<T>>(stream, output, x, y);
}

template <class T>
void eltwise_sum_2(const Stream& stream, TensorSpan<T> output, TensorView<T> x, TensorView<T> y) {
    eltwise_op<T, SumFunctor<T>>(stream, output, x, y);
}

/* weighted sum of the two inputs using the given coefficients */
template <class T>
void eltwise_sum_coeff_2(const Stream& stream, TensorSpan<T> output, T coeff_x, TensorView<T> x, T coeff_y, TensorView<T> y) {
    eltwise_op<T, ScaledSumFunctor<T>>(stream, output, x, y, {coeff_x, coeff_y});
}
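
/* eltwise_sum_coeff_2 is the only wrapper here that forwards extra parameters: the brace-initialized
 * {coeff_x, coeff_y} travels to the functor as its Params. ScaledSumFunctor itself is not part of this
 * excerpt; the struct below is only a guess at its general shape (a coefficient-weighted sum), written
 * under a hypothetical name.
 */
template <class T>
struct scaled_sum_functor_sketch {
    struct Params {
        T coeff_x, coeff_y;
    };

    __device__ scaled_sum_functor_sketch(const Params& params) : coeff_x(params.coeff_x), coeff_y(params.coeff_y) { }
    __device__ T operator()(T x, T y) const { return coeff_x * x + coeff_y * y; }

    T coeff_x, coeff_y;
};
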
template <class T>
void eltwise_prod_2(const Stream& stream, TensorSpan<T> output, TensorView<T> x, TensorView<T> y) {
    eltwise_op<T, ProductFunctor<T>>(stream, output, x, y);
}

template <class T>
void eltwise_div_2(const Stream& stream, TensorSpan<T> output, TensorView<T> x, TensorView<T> y) {
    eltwise_op<T, DivFunctor<T>>(stream, output, x, y);
}

/* explicit instantiations for half precision */
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
    template void eltwise_div_2(const Stream& stream, TensorSpan<__half> output, TensorView<__half> x, TensorView<__half> y);
    template void eltwise_prod_2(const Stream& stream, TensorSpan<__half> output, TensorView<__half> x, TensorView<__half> y);
    template void eltwise_sum_coeff_2(const Stream&, TensorSpan<__half>, __half, TensorView<__half>, __half, TensorView<__half>);
    template void eltwise_sum_2(const Stream& stream, TensorSpan<__half> output, TensorView<__half> x, TensorView<__half> y);
    template void eltwise_max_2(const Stream& stream, TensorSpan<__half> output, TensorView<__half> x, TensorView<__half> y);
    template void eltwise_min_2(const Stream& stream, TensorSpan<__half> output, TensorView<__half> x, TensorView<__half> y);
#endif

/* explicit instantiations for single precision */
template void eltwise_div_2(const Stream& stream, TensorSpan<float> output, TensorView<float> x, TensorView<float> y);
template void eltwise_prod_2(const Stream& stream, TensorSpan<float> output, TensorView<float> x, TensorView<float> y);
template void eltwise_sum_coeff_2(const Stream&, TensorSpan<float>, float, TensorView<float>, float, TensorView<float>);
template void eltwise_sum_2(const Stream& stream, TensorSpan<float> output, TensorView<float> x, TensorView<float> y);
template void eltwise_max_2(const Stream& stream, TensorSpan<float> output, TensorView<float> x, TensorView<float> y);
template void eltwise_min_2(const Stream& stream, TensorSpan<float> output, TensorView<float> x, TensorView<float> y);
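
/* The half-precision instantiations above are guarded because device-side __half arithmetic
 * (operator+, operator*, ...) is available only from compute capability 5.3 onward; the
 * `!defined(__CUDA_ARCH__)` half of the condition keeps the host compilation pass (which has no
 * __CUDA_ARCH__) working. A minimal sketch of the same pattern follows, under hypothetical names;
 * it assumes only <cuda_fp16.h>, which this translation unit already requires for __half.
 */
namespace guard_sketch {
    /* illustrative kernel template: element-wise sum over a flat array */
    template <class T>
    __global__ void sum_kernel(T* out, const T* x, const T* y, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n)
            out[i] = x[i] + y[i]; /* for __half, device-side operator+ needs sm_53+ */
    }

    /* host-side launcher, analogous in spirit to the eltwise_*_2 wrappers above */
    template <class T>
    void sum(cudaStream_t stream, T* out, const T* x, const T* y, int n) {
        if (n <= 0) return;
        sum_kernel<T><<<(n + 255) / 256, 256, 0, stream>>>(out, x, y, n);
    }

#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
    template void sum(cudaStream_t, __half*, const __half*, const __half*, int);
#endif
    template void sum(cudaStream_t, float*, const float*, const float*, int);
}
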
}}}} /* namespace cv::dnn::cuda4dnn::kernels */