mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 17:44:04 +08:00
add ONNX OP sign, shrink and reciprocal
This commit is contained in:
parent
9aa647068b
commit
e36948cfbc
@ -794,6 +794,26 @@ CV__DNN_INLINE_NS_BEGIN
|
||||
static Ptr<ActivationLayerInt8> create(const LayerParams ¶ms);
|
||||
};
|
||||
|
||||
class CV_EXPORTS SignLayer : public ActivationLayer
|
||||
{
|
||||
public:
|
||||
static Ptr<SignLayer> create(const LayerParams ¶ms);
|
||||
};
|
||||
|
||||
class CV_EXPORTS ShrinkLayer : public ActivationLayer
|
||||
{
|
||||
public:
|
||||
float bias;
|
||||
float lambd;
|
||||
static Ptr<ShrinkLayer> create(const LayerParams ¶ms);
|
||||
};
|
||||
|
||||
class CV_EXPORTS ReciprocalLayer : public ActivationLayer
|
||||
{
|
||||
public:
|
||||
static Ptr<ReciprocalLayer> create(const LayerParams ¶ms);
|
||||
};
|
||||
|
||||
/* Layers used in semantic segmentation */
|
||||
|
||||
class CV_EXPORTS CropLayer : public Layer
|
||||
|
@ -248,6 +248,21 @@ void selu(const Stream& stream, Span<T> output, View<T> input, T alpha, T gamma)
|
||||
generic_op<T, SeluFunctor<T>>(stream, output, input, {alpha, gamma});
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void sign(const Stream& stream, Span<T> output, View<T> input) {
|
||||
generic_op<T, SignFunctor<T>>(stream, output, input);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void shrink(const Stream& stream, Span<T> output, View<T> input, T bias, T lambd) {
|
||||
generic_op<T, ShrinkFunctor<T>>(stream, output, input, {bias, lambd});
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void reciprocal(const Stream& stream, Span<T> output, View<T> input) {
|
||||
generic_op<T, SignFunctor<T>>(stream, output, input);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void thresholdedrelu(const Stream& stream, Span<T> output, View<T> input, T alpha) {
|
||||
generic_op<T, ThresholdedReluFunctor<T>>(stream, output, input, {alpha});
|
||||
@ -312,6 +327,9 @@ template void selu<__half>(const Stream&, Span<__half>, View<__half>, __half, __
|
||||
template void thresholdedrelu<__half>(const Stream&, Span<__half>, View<__half>, __half);
|
||||
template void power<__half>(const Stream&, Span<__half>, View<__half>, __half, __half, __half);
|
||||
template void exp<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
|
||||
template void sign<__half>(const Stream&, Span<__half>, View<__half>);
|
||||
template void shrink<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
|
||||
template void reciprocal<__half>(const Stream&, Span<__half>, View<__half>);
|
||||
#endif
|
||||
|
||||
|
||||
@ -351,6 +369,9 @@ template void selu<float>(const Stream&, Span<float>, View<float>, float, float)
|
||||
template void thresholdedrelu<float>(const Stream&, Span<float>, View<float>, float);
|
||||
template void power<float>(const Stream&, Span<float>, View<float>, float, float, float);
|
||||
template void exp<float>(const Stream&, Span<float>, View<float>, float, float);
|
||||
template void sign<float>(const Stream&, Span<float>, View<float>);
|
||||
template void shrink<float>(const Stream&, Span<float>, View<float>, float, float);
|
||||
template void reciprocal<float>(const Stream&, Span<float>, View<float>);
|
||||
|
||||
template <class T, std::size_t N> static
|
||||
void launch_vectorized_axiswise_relu(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> slope) {
|
||||
|
@ -726,6 +726,50 @@ struct DivFunctor {
|
||||
CUDA4DNN_DEVICE T operator()(T x, T y) { return x / y; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct SignFunctor {
|
||||
struct Params {
|
||||
CUDA4DNN_HOST_DEVICE Params() {}
|
||||
};
|
||||
|
||||
CUDA4DNN_DEVICE SignFunctor() : SignFunctor(Params{}) { }
|
||||
|
||||
CUDA4DNN_DEVICE T operator()(T value) {
|
||||
return value > T(0) ? T(1) : (value < T(0) ? T(-1) : T(0));
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct ShrinkFunctor {
|
||||
struct Params {
|
||||
CUDA4DNN_HOST_DEVICE Params() : bias(0), lambd(0.5) { }
|
||||
CUDA4DNN_HOST_DEVICE Params(T bias_, T lambd_) : bias(bias_), lambd(lambd_) { }
|
||||
T bias, lambd;
|
||||
};
|
||||
|
||||
CUDA4DNN_DEVICE ShrinkFunctor() : bias(0), lambd(0.5) { }
|
||||
CUDA4DNN_DEVICE ShrinkFunctor(const Params& params) : bias{params.bias}, lambd{params.lambd} { }
|
||||
|
||||
CUDA4DNN_DEVICE T operator()(T value) {
|
||||
return value > lambd ? value - bias : (value < -lambd ? value + bias : T(0));
|
||||
}
|
||||
|
||||
T bias, lambd;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct ReciprocalFunctor {
|
||||
struct Params {
|
||||
CUDA4DNN_HOST_DEVICE Params() {}
|
||||
};
|
||||
|
||||
CUDA4DNN_DEVICE ReciprocalFunctor() : ReciprocalFunctor(Params{}) { }
|
||||
|
||||
CUDA4DNN_DEVICE T operator()(T value) {
|
||||
return T(1.0f)/value;
|
||||
}
|
||||
};
|
||||
|
||||
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
|
||||
|
||||
#endif /* OPENCV_DNN_SRC_CUDA_FUNCTORS_HPP */
|
||||
|
@ -123,6 +123,14 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
|
||||
template <class T>
|
||||
void exp(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, T normScale, T normShift);
|
||||
|
||||
template <class T>
|
||||
void sign(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input);
|
||||
|
||||
template <class T>
|
||||
void shrink(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, T bias, T lambd);
|
||||
|
||||
template <class T>
|
||||
void reciprocal(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input);
|
||||
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
|
||||
|
||||
#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ACTIVATIONS_HPP */
|
||||
|
@ -584,6 +584,52 @@ namespace cv { namespace dnn { namespace cuda4dnn {
|
||||
const T normScale, normShift;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class ShrinkOp final : public BaseOp<ShrinkOp, T> {
|
||||
public:
|
||||
ShrinkOp(csl::Stream stream_, T bias_, T lambd_)
|
||||
: stream(std::move(stream_)), bias{ bias_ }, lambd{ lambd_ } { }
|
||||
|
||||
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
|
||||
{
|
||||
kernels::shrink<T>(stream, output, input, bias, lambd);
|
||||
}
|
||||
|
||||
private:
|
||||
csl::Stream stream;
|
||||
const T bias, lambd;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class SignOp final : public BaseOp<SignOp, T> {
|
||||
public:
|
||||
SignOp(csl::Stream stream_)
|
||||
: stream(std::move(stream_)) { }
|
||||
|
||||
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
|
||||
{
|
||||
kernels::sign<T>(stream, output, input);
|
||||
}
|
||||
|
||||
private:
|
||||
csl::Stream stream;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class ReciprocalOp final : public BaseOp<ReciprocalOp, T> {
|
||||
public:
|
||||
ReciprocalOp(csl::Stream stream_)
|
||||
: stream(std::move(stream_)) { }
|
||||
|
||||
void calculate(csl::TensorSpan<T> output, csl::TensorView<T> input) const
|
||||
{
|
||||
kernels::reciprocal<T>(stream, output, input);
|
||||
}
|
||||
|
||||
private:
|
||||
csl::Stream stream;
|
||||
};
|
||||
|
||||
}}} /* namespace cv::dnn::cuda4dnn */
|
||||
|
||||
#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_ACTIVATION_HPP */
|
||||
|
@ -130,6 +130,8 @@ void initializeLayerFactory()
|
||||
CV_DNN_REGISTER_LAYER_CLASS(HardSwish, HardSwishLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Sin, SinLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Sinh, SinhLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Sign, SignLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Shrink, ShrinkLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Softplus, SoftplusLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Softsign, SoftsignLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Tan, TanLayer);
|
||||
@ -144,6 +146,7 @@ void initializeLayerFactory()
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Silence, BlankLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Const, ConstLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Arg, ArgLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Reciprocal, ReciprocalLayer);
|
||||
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Crop, CropLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Eltwise, EltwiseLayer);
|
||||
|
@ -2270,6 +2270,96 @@ struct ChannelsPReLUFunctor : public BaseFunctor
|
||||
int64 getFLOPSPerElement() const { return 1; }
|
||||
};
|
||||
|
||||
struct SignFunctor : public BaseDefaultFunctor<SignFunctor>
|
||||
{
|
||||
typedef SignLayer Layer;
|
||||
|
||||
bool supportBackend(int backendId, int)
|
||||
{
|
||||
return backendId == DNN_BACKEND_OPENCV ||
|
||||
backendId == DNN_BACKEND_CUDA;
|
||||
}
|
||||
|
||||
inline float calculate(float x) const
|
||||
{
|
||||
return x > 0 ? 1 : (x < 0 ? -1 : 0);
|
||||
}
|
||||
|
||||
#ifdef HAVE_CUDA
|
||||
Ptr<BackendNode> initCUDA(int target, csl::Stream stream)
|
||||
{
|
||||
return make_cuda_node<cuda4dnn::SignOp>(target, stream);
|
||||
}
|
||||
#endif
|
||||
|
||||
int64 getFLOPSPerElement() const { return 1; }
|
||||
};
|
||||
|
||||
template<>
|
||||
const char* const SignFunctor::BaseDefaultFunctor<SignFunctor>::ocl_kernel_name = "SignForward";
|
||||
|
||||
|
||||
struct ShrinkFunctor : public BaseDefaultFunctor<ShrinkFunctor>
|
||||
{
|
||||
typedef ShrinkLayer Layer;
|
||||
float bias;
|
||||
float lambd;
|
||||
|
||||
explicit ShrinkFunctor(float bias_ = 0.0f, float lambd_ = 0.5f) : bias(bias_), lambd(lambd_) {}
|
||||
|
||||
bool supportBackend(int backendId, int)
|
||||
{
|
||||
return backendId == DNN_BACKEND_OPENCV ||
|
||||
backendId == DNN_BACKEND_CUDA;
|
||||
}
|
||||
|
||||
inline float calculate(float x) const
|
||||
{
|
||||
return x > lambd ? x - bias : (x < -lambd ? x + bias : 0);
|
||||
}
|
||||
|
||||
#ifdef HAVE_CUDA
|
||||
Ptr<BackendNode> initCUDA(int target, csl::Stream stream)
|
||||
{
|
||||
return make_cuda_node<cuda4dnn::ShrinkOp>(target, stream);
|
||||
}
|
||||
#endif
|
||||
|
||||
int64 getFLOPSPerElement() const { return 1; }
|
||||
};
|
||||
|
||||
template<>
|
||||
const char* const ShrinkFunctor::BaseDefaultFunctor<ShrinkFunctor>::ocl_kernel_name = "ShrinkForward";
|
||||
|
||||
struct ReciprocalFunctor : public BaseDefaultFunctor<ReciprocalFunctor>
|
||||
{
|
||||
typedef ReciprocalLayer Layer;
|
||||
|
||||
bool supportBackend(int backendId, int)
|
||||
{
|
||||
return backendId == DNN_BACKEND_OPENCV ||
|
||||
backendId == DNN_BACKEND_CUDA;
|
||||
}
|
||||
|
||||
inline float calculate(float x) const
|
||||
{
|
||||
return 1.0/x;
|
||||
}
|
||||
|
||||
#ifdef HAVE_CUDA
|
||||
Ptr<BackendNode> initCUDA(int target, csl::Stream stream)
|
||||
{
|
||||
return make_cuda_node<cuda4dnn::ReciprocalOp>(target, stream);
|
||||
}
|
||||
#endif
|
||||
|
||||
int64 getFLOPSPerElement() const { return 1; }
|
||||
};
|
||||
|
||||
template<>
|
||||
const char* const ReciprocalFunctor::BaseDefaultFunctor<ReciprocalFunctor>::ocl_kernel_name = "ReciprocalForward";
|
||||
|
||||
|
||||
#define ACTIVATION_CREATOR_FOR(_Layer, _Functor, ...) \
|
||||
Ptr<_Layer> _Layer::create() { \
|
||||
return return Ptr<_Layer>( new ElementWiseLayer<_Functor>(_Functor()) ); }
|
||||
@ -2611,5 +2701,32 @@ Ptr<Layer> ChannelsPReLULayer::create(const LayerParams& params)
|
||||
return l;
|
||||
}
|
||||
|
||||
Ptr<SignLayer> SignLayer::create(const LayerParams& params)
|
||||
{
|
||||
Ptr<SignLayer> l(new ElementWiseLayer<SignFunctor>());
|
||||
l->setParamsFrom(params);
|
||||
|
||||
return l;
|
||||
}
|
||||
|
||||
Ptr<ReciprocalLayer> ReciprocalLayer::create(const LayerParams& params)
|
||||
{
|
||||
Ptr<ReciprocalLayer> l(new ElementWiseLayer<ReciprocalFunctor>());
|
||||
l->setParamsFrom(params);
|
||||
|
||||
return l;
|
||||
}
|
||||
|
||||
Ptr<ShrinkLayer> ShrinkLayer::create(const LayerParams& params)
|
||||
{
|
||||
float bias = params.get<float>("bias", 0.f);
|
||||
float lambd = params.get<float>("lambd", 0.5f);
|
||||
Ptr<ShrinkLayer> l(new ElementWiseLayer<ShrinkFunctor>(ShrinkFunctor(bias, lambd)));
|
||||
l->setParamsFrom(params);
|
||||
l->bias = bias;
|
||||
l->lambd = lambd;
|
||||
|
||||
return l;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3675,8 +3675,8 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
|
||||
|
||||
std::vector<std::string> simpleLayers{"Acos", "Acosh", "Asin", "Asinh", "Atan", "Atanh", "Ceil", "Celu", "Cos",
|
||||
"Cosh", "Dropout", "Erf", "Exp", "Floor", "HardSigmoid", "HardSwish",
|
||||
"Identity", "Log", "Round", "Selu", "Sigmoid", "Sin", "Sinh", "Softmax",
|
||||
"Softplus", "Softsign", "Sqrt", "Tan", "ThresholdedRelu"};
|
||||
"Identity", "Log", "Round", "Reciprocal", "Selu", "Sign", "Sigmoid", "Sin", "Sinh", "Softmax",
|
||||
"Softplus", "Softsign", "Shrink", "Sqrt", "Tan", "ThresholdedRelu"};
|
||||
for (const auto& name : simpleLayers)
|
||||
{
|
||||
dispatch[name] = &ONNXImporter::parseSimpleLayers;
|
||||
|
@ -306,3 +306,26 @@ __kernel void ThresholdedReluForward(const int n, __global T* in, __global T* ou
|
||||
if(index < n)
|
||||
out[index] = (in[index] > alpha ? in[index] : 0.f);
|
||||
}
|
||||
|
||||
__kernel void ShrinkForward(const int n, __global T* in, __global T* out,
|
||||
const KERNEL_ARG_DTYPE bias,
|
||||
const KERNEL_ARG_DTYPE lambd)
|
||||
{
|
||||
int index = get_global_id(0);
|
||||
if(index < n)
|
||||
out[index] = in[index] < -lambd ? in[index] + bias : (in[index] > lambd ? in[index] - bias : 0.f);
|
||||
}
|
||||
|
||||
__kernel void SignForward(const int n, __global T* in, __global T* out)
|
||||
{
|
||||
int index = get_global_id(0);
|
||||
if(index < n)
|
||||
out[index] = in[index] > 0.f ? 1.0f : (in[index] < 0.f) ? -1.0f : 0.0f);
|
||||
}
|
||||
|
||||
__kernel void ReciprocalForward(const int n, __global T* in, __global T* out)
|
||||
{
|
||||
int index = get_global_id(0);
|
||||
if(index < n)
|
||||
out[index] = 1.0f/in[index];
|
||||
}
|
@ -337,8 +337,6 @@
|
||||
"test_range_float_type_positive_delta_expanded",
|
||||
"test_range_int32_type_negative_delta",
|
||||
"test_range_int32_type_negative_delta_expanded",
|
||||
"test_reciprocal",
|
||||
"test_reciprocal_example",
|
||||
"test_reduce_sum_default_axes_keepdims_example",
|
||||
"test_reduce_sum_default_axes_keepdims_random",
|
||||
"test_reduce_sum_do_not_keepdims_example",
|
||||
@ -479,9 +477,6 @@
|
||||
"test_shape_start_1_end_2",
|
||||
"test_shape_start_1_end_negative_1",
|
||||
"test_shape_start_negative_1",
|
||||
"test_shrink_hard",
|
||||
"test_shrink_soft",
|
||||
"test_sign",
|
||||
"test_simple_rnn_batchwise",
|
||||
"test_simple_rnn_defaults",
|
||||
"test_simple_rnn_with_initial_bias",
|
||||
|
Loading…
Reference in New Issue
Block a user