Merge pull request #17621 from YashasSamaga:cuda4dnn-mish-half

This commit is contained in:
Alexander Alekhin 2020-06-23 18:44:50 +00:00
commit 6fb46bb34b

View File

@@ -57,11 +57,19 @@ struct mish_functor<float> {
auto n = e * e + 2 * e; auto n = e * e + 2 * e;
if (value <= -0.6f) if (value <= -0.6f)
return value * fast_divide(n, n + 2); return value * fast_divide(n, n + 2);
return value - 2 * fast_divide(value, n + 2); return value - 2 * fast_divide(value, n + 2);
} }
}; };
// Half-precision specialization of mish_functor.
//
// Only compiled when __CUDA_ARCH__ is undefined or is at least 530.
// NOTE(review): presumably this mirrors the architecture requirement for
// __half support used elsewhere in the project — confirm against callers.
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <>
struct mish_functor<__half> {
    // Evaluates the activation for a __half input by delegating to the
    // float specialization: the argument is widened to float, passed to
    // mish_functor<float>, and the float result is narrowed back to __half.
    __device__ __half operator()(__half value) {
        const float widened = __half2float(value);
        mish_functor<float> float_impl;
        const float activated = float_impl(widened);
        return __float2half(activated);
    }
};
#endif
template <class T> template <class T>
struct sigmoid_functor { struct sigmoid_functor {
__device__ T operator()(T value) { __device__ T operator()(T value) {