Merge pull request #17534 from YashasSamaga:cuda4dnn-remove-unused-funcs

cuda4dnn: reduce CUDA version requirements to at least CUDA 9.2

* remove half2 specializations

* do not remove atomicAdd for half in CUDA 10 and below

* remove fp16.hpp
This commit is contained in:
Yashas Samaga B L 2020-06-17 14:37:52 +05:30 committed by GitHub
parent 6fdddd53a1
commit 9ba5581d17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 2 additions and 128 deletions

View File

@ -9,8 +9,9 @@
#include <cuda_fp16.h>
// The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
// This function was introduced in CUDA 10.
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomicadd
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700 && CUDART_VERSION >= 10000)
// And half-precision floating-point operations are not supported by devices of compute capability strictly lower than 5.3
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications
#elif __CUDA_ARCH__ < 530

View File

@ -11,20 +11,12 @@
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device {
template <class T> __device__ T abs(T val) { return (val < T(0) ? -val : val); }
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <> inline __device__ __half2 abs(__half2 val) {
val.x = abs(val.x);
val.y = abs(val.y);
return val;
}
#endif
template <> inline __device__ float abs(float val) { return fabsf(val); }
template <> inline __device__ double abs(double val) { return fabs(val); }
template <class T> __device__ T exp(T val);
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <> inline __device__ __half exp(__half val) { return hexp(val); }
template <> inline __device__ __half2 exp(__half2 val) { return h2exp(val); }
#endif
template <> inline __device__ float exp(float val) { return expf(val); }
template <> inline __device__ double exp(double val) { return ::exp(val); }
@ -32,37 +24,21 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace de
template <class T> __device__ T expm1(T val);
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <> inline __device__ __half expm1(__half val) { return hexp(val) - __half(1); }
template <> inline __device__ __half2 expm1(__half2 val) { return h2exp(val) - __half2(1, 1); }
#endif
template <> inline __device__ float expm1(float val) { return expm1f(val); }
template <> inline __device__ double expm1(double val) { return ::expm1(val); }
template <class T> __device__ T max(T x, T y) { return (x > y ? x : y); }
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <> inline __device__ __half2 max(__half2 a, __half2 b) {
a.x = max(a.x, a.x);
a.y = max(a.y, b.y);
return a;
}
#endif
template <> inline __device__ float max(float x, float y) { return fmaxf(x, y); }
template <> inline __device__ double max(double x, double y) { return fmax(x, y); }
template <class T> __device__ T min(T x, T y) { return (x > y ? y : x); }
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <> inline __device__ __half2 min(__half2 a, __half2 b) {
a.x = min(a.x, a.x);
a.y = min(a.y, b.y);
return a;
}
#endif
template <> inline __device__ float min(float x, float y) { return fminf(x, y); }
template <> inline __device__ double min(double x, double y) { return fmin(x, y); }
template <class T> __device__ T log1p(T val);
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <> inline __device__ __half log1p(__half val) { return hlog(__half(1) + val); }
template <> inline __device__ __half2 log1p(__half2 val) { return h2log(__half2(1, 1) + val); }
#endif
template <> inline __device__ float log1p(float val) { return log1pf(val); }
@ -78,11 +54,6 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace de
else
return val;
}
template <> inline __device__ __half2 log1pexp(__half2 val) {
val.x = log1pexp(val.x);
val.y = log1pexp(val.y);
return val;
}
#endif
template <> inline __device__ float log1pexp(float val) {
if (val <= -20)
@ -108,7 +79,6 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace de
template <class T> __device__ T tanh(T val);
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <> inline __device__ __half tanh(__half val) { return tanhf(val); }
template <> inline __device__ __half2 tanh(__half2 val) { return __half2(tanh(val.x), tanh(val.y)); }
#endif
template <> inline __device__ float tanh(float val) { return tanhf(val); }
template <> inline __device__ double tanh(double val) { return ::tanh(val); }
@ -116,7 +86,6 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace de
template <class T> __device__ T pow(T val, T exp);
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <> inline __device__ __half pow(__half val, __half exp) { return powf(val, exp); }
template <> inline __device__ __half2 pow(__half2 val, __half2 exp) { return __half2(pow(val.x, exp.x), pow(val.y, exp.y)); }
#endif
template <> inline __device__ float pow(float val, float exp) { return powf(val, exp); }
template <> inline __device__ double pow(double val, double exp) { return ::pow(val, exp); }
@ -124,7 +93,6 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace de
template <class T> __device__ T sqrt(T val);
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <> inline __device__ __half sqrt(__half val) { return hsqrt(val); }
template <> inline __device__ __half2 sqrt(__half2 val) { return h2sqrt(val); }
#endif
template <> inline __device__ float sqrt(float val) { return sqrtf(val); }
template <> inline __device__ double sqrt(double val) { return ::sqrt(val); }
@ -132,15 +100,11 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace de
template <class T> __device__ T rsqrt(T val);
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <> inline __device__ __half rsqrt(__half val) { return hrsqrt(val); }
template <> inline __device__ __half2 rsqrt(__half2 val) { return h2rsqrt(val); }
#endif
template <> inline __device__ float rsqrt(float val) { return rsqrtf(val); }
template <> inline __device__ double rsqrt(double val) { return ::rsqrt(val); }
template <class T> __device__ T sigmoid(T val) { return T(1) / (T(1) + exp(-val)); }
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <> inline __device__ __half2 sigmoid(__half2 val) { return __half2(1, 1) / (__half2(1, 1) + exp(__hneg2(val))); }
#endif
template <class T> __device__ T clamp(T value, T lower, T upper) { return min(max(value, lower), upper); }
@ -149,7 +113,6 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace de
template <> inline __device__ float round(float value) { return roundf(value); }
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <> inline __device__ __half round(__half value) { return hrint(value); }
template <> inline __device__ __half2 round(__half2 value) { return h2rint(value); }
#endif
template <class T> __device__ T ceil(T value);
@ -157,7 +120,6 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace de
template <> inline __device__ float ceil(float value) { return ceilf(value); }
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <> inline __device__ __half ceil(__half value) { return hceil(value); }
template <> inline __device__ __half2 ceil(__half2 value) { return h2ceil(value); }
#endif
template <class T> __device__ T fast_divide(T x, T y) { return x / y; }

View File

@ -8,7 +8,6 @@
#include "error.hpp"
#include "stream.hpp"
#include "pointer.hpp"
#include "fp16.hpp"
#include <opencv2/core.hpp>

View File

@ -5,7 +5,6 @@
#ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_CUDNN_HPP
#define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_CUDNN_HPP
#include "../fp16.hpp"
#include "../pointer.hpp"
#include <cudnn.h>

View File

@ -1,86 +0,0 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_FP16_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_CSL_FP16_HPP
#include "nvcc_defs.hpp"
#include <cuda_fp16.h>
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
namespace detail {
template <class T, class = void>
struct is_half_convertible : std::false_type { };
template <class T>
struct is_half_convertible<T, typename std::enable_if<std::is_integral<T>::value, void>::type> : std::true_type { };
template <class T>
struct is_half_convertible<T, typename std::enable_if<std::is_floating_point<T>::value, void>::type> : std::true_type { };
}
/* Note: nvcc has a broken overload resolution; it considers host overloads inside device code
CUDA4DNN_HOST bool operator==(half lhs, half rhs) noexcept { return static_cast<float>(lhs) == static_cast<float>(rhs); }
CUDA4DNN_HOST bool operator!=(half lhs, half rhs) noexcept { return static_cast<float>(lhs) != static_cast<float>(rhs); }
CUDA4DNN_HOST bool operator<(half lhs, half rhs) noexcept { return static_cast<float>(lhs) < static_cast<float>(rhs); }
CUDA4DNN_HOST bool operator>(half lhs, half rhs) noexcept { return static_cast<float>(lhs) > static_cast<float>(rhs); }
CUDA4DNN_HOST bool operator<=(half lhs, half rhs) noexcept { return static_cast<float>(lhs) <= static_cast<float>(rhs); }
CUDA4DNN_HOST bool operator>=(half lhs, half rhs) noexcept { return static_cast<float>(lhs) >= static_cast<float>(rhs); }
*/
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <class T> CUDA4DNN_HOST
typename std::enable_if<detail::is_half_convertible<T>::value, bool>
::type operator==(half lhs, T rhs) noexcept { return static_cast<float>(lhs) == static_cast<float>(rhs); }
template <class T> CUDA4DNN_HOST
typename std::enable_if<detail::is_half_convertible<T>::value, bool>
::type operator!=(half lhs, T rhs) noexcept { return static_cast<float>(lhs) != static_cast<float>(rhs); }
template <class T> CUDA4DNN_HOST
typename std::enable_if<detail::is_half_convertible<T>::value, bool>
::type operator<(half lhs, T rhs) noexcept { return static_cast<float>(lhs) < static_cast<float>(rhs); }
template <class T> CUDA4DNN_HOST
typename std::enable_if<detail::is_half_convertible<T>::value, bool>
::type operator>(half lhs, T rhs) noexcept { return static_cast<float>(lhs) > static_cast<float>(rhs); }
template <class T> CUDA4DNN_HOST
typename std::enable_if<detail::is_half_convertible<T>::value, bool>
::type operator<=(half lhs, T rhs) noexcept { return static_cast<float>(lhs) <= static_cast<float>(rhs); }
template <class T> CUDA4DNN_HOST
typename std::enable_if<detail::is_half_convertible<T>::value, bool>
::type operator>=(half lhs, T rhs) noexcept { return static_cast<float>(lhs) >= static_cast<float>(rhs); }
template <class T> CUDA4DNN_HOST
typename std::enable_if<detail::is_half_convertible<T>::value, bool>
::type operator==(T lhs, half rhs) noexcept { return static_cast<float>(lhs) == static_cast<float>(rhs); }
template <class T> CUDA4DNN_HOST
typename std::enable_if<detail::is_half_convertible<T>::value, bool>
::type operator!=(T lhs, half rhs) noexcept { return static_cast<float>(lhs) != static_cast<float>(rhs); }
template <class T> CUDA4DNN_HOST
typename std::enable_if<detail::is_half_convertible<T>::value, bool>
::type operator<(T lhs, half rhs) noexcept { return static_cast<float>(lhs) < static_cast<float>(rhs); }
template <class T> CUDA4DNN_HOST
typename std::enable_if<detail::is_half_convertible<T>::value, bool>
::type operator>(T lhs, half rhs) noexcept { return static_cast<float>(lhs) > static_cast<float>(rhs); }
template <class T> CUDA4DNN_HOST
typename std::enable_if<detail::is_half_convertible<T>::value, bool>
::type operator<=(T lhs, half rhs) noexcept { return static_cast<float>(lhs) <= static_cast<float>(rhs); }
template <class T> CUDA4DNN_HOST
typename std::enable_if<detail::is_half_convertible<T>::value, bool>
::type operator>=(T lhs, half rhs) noexcept { return static_cast<float>(lhs) >= static_cast<float>(rhs); }
#endif
}}}} /* namespace cv::dnn::cuda4dnn::csl */
#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_FP16_HPP */

View File

@ -11,7 +11,6 @@
#include "cuda4dnn/csl/cudnn.hpp"
#include "cuda4dnn/csl/tensor.hpp"
#include "cuda4dnn/csl/memory.hpp"
#include "cuda4dnn/csl/fp16.hpp"
#include "cuda4dnn/csl/workspace.hpp"
#include "cuda4dnn/kernels/fp_conversion.hpp"
#endif