mirror of
https://github.com/opencv/opencv.git
synced 2025-01-18 22:44:02 +08:00
Merge pull request #21160 from rogday:elu_alpha
This commit is contained in:
commit
dad2b9aac8
@ -453,6 +453,8 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
|
||||
class CV_EXPORTS ELULayer : public ActivationLayer
|
||||
{
|
||||
public:
|
||||
float alpha;
|
||||
|
||||
static Ptr<ELULayer> create(const LayerParams ¶ms);
|
||||
};
|
||||
|
||||
|
@ -740,6 +740,9 @@ const char* const SigmoidFunctor::BaseDefaultFunctor<SigmoidFunctor>::ocl_kernel
|
||||
struct ELUFunctor : public BaseDefaultFunctor<ELUFunctor>
|
||||
{
|
||||
typedef ELULayer Layer;
|
||||
float alpha;
|
||||
|
||||
explicit ELUFunctor(float alpha_ = 1.f) : alpha(alpha_) {}
|
||||
|
||||
bool supportBackend(int backendId, int)
|
||||
{
|
||||
@ -749,14 +752,19 @@ struct ELUFunctor : public BaseDefaultFunctor<ELUFunctor>
|
||||
|
||||
inline float calculate(float x) const
|
||||
{
|
||||
return x >= 0.f ? x : exp(x) - 1.f;
|
||||
return x >= 0.f ? x : alpha * (exp(x) - 1.f);
|
||||
}
|
||||
|
||||
inline void setKernelParams(ocl::Kernel& kernel) const
|
||||
{
|
||||
kernel.set(3, alpha);
|
||||
}
|
||||
|
||||
#ifdef HAVE_HALIDE
|
||||
void attachHalide(const Halide::Expr& input, Halide::Func& top)
|
||||
{
|
||||
Halide::Var x("x"), y("y"), c("c"), n("n");
|
||||
top(x, y, c, n) = select(input >= 0.0f, input, exp(input) - 1);
|
||||
top(x, y, c, n) = select(input >= 0.0f, input, alpha * (exp(input) - 1));
|
||||
}
|
||||
#endif // HAVE_HALIDE
|
||||
|
||||
@ -770,7 +778,7 @@ struct ELUFunctor : public BaseDefaultFunctor<ELUFunctor>
|
||||
#ifdef HAVE_DNN_NGRAPH
|
||||
std::shared_ptr<ngraph::Node> initNgraphAPI(const std::shared_ptr<ngraph::Node>& node)
|
||||
{
|
||||
return std::make_shared<ngraph::op::Elu>(node, 1.0);
|
||||
return std::make_shared<ngraph::op::Elu>(node, alpha);
|
||||
}
|
||||
#endif // HAVE_DNN_NGRAPH
|
||||
|
||||
@ -1263,8 +1271,10 @@ Ptr<SigmoidLayer> SigmoidLayer::create(const LayerParams& params)
|
||||
|
||||
Ptr<ELULayer> ELULayer::create(const LayerParams& params)
|
||||
{
|
||||
Ptr<ELULayer> l(new ElementWiseLayer<ELUFunctor>(ELUFunctor()));
|
||||
float alpha = params.get<float>("alpha", 1.0f);
|
||||
Ptr<ELULayer> l(new ElementWiseLayer<ELUFunctor>(ELUFunctor(alpha)));
|
||||
l->setParamsFrom(params);
|
||||
l->alpha = alpha;
|
||||
|
||||
return l;
|
||||
}
|
||||
|
@ -131,13 +131,14 @@ __kernel void PowForward(const int n, __global const T* in, __global T* out,
|
||||
out[index] = pow(shift + scale * in[index], power);
|
||||
}
|
||||
|
||||
__kernel void ELUForward(const int n, __global const T* in, __global T* out)
|
||||
__kernel void ELUForward(const int n, __global const T* in, __global T* out,
|
||||
const KERNEL_ARG_DTYPE alpha)
|
||||
{
|
||||
int index = get_global_id(0);
|
||||
if (index < n)
|
||||
{
|
||||
T src = in[index];
|
||||
out[index] = (src >= 0.f) ? src : exp(src) - 1;
|
||||
out[index] = (src >= 0.f) ? src : alpha * (exp(src) - 1);
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user