mirror of
https://github.com/opencv/opencv.git
synced 2025-06-11 11:45:30 +08:00
Support Swish and Mish activations
This commit is contained in:
parent
d99d18304a
commit
660a709840
@ -579,7 +579,7 @@ struct SwishFunctor
|
||||
bool supportBackend(int backendId, int)
|
||||
{
|
||||
return backendId == DNN_BACKEND_OPENCV ||
|
||||
backendId == DNN_BACKEND_HALIDE;
|
||||
backendId == DNN_BACKEND_HALIDE || backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH;;
|
||||
}
|
||||
|
||||
void apply(const float* srcptr, float* dstptr, int len, size_t planeSize, int cn0, int cn1) const
|
||||
@ -640,7 +640,8 @@ struct SwishFunctor
|
||||
#ifdef HAVE_DNN_NGRAPH
|
||||
std::shared_ptr<ngraph::Node> initNgraphAPI(const std::shared_ptr<ngraph::Node>& node)
|
||||
{
|
||||
CV_Error(Error::StsNotImplemented, "");
|
||||
auto sigmoid = std::make_shared<ngraph::op::Sigmoid>(node);
|
||||
return std::make_shared<ngraph::op::v1::Multiply>(node, sigmoid);
|
||||
}
|
||||
#endif // HAVE_DNN_NGRAPH
|
||||
|
||||
@ -659,7 +660,7 @@ struct MishFunctor
|
||||
bool supportBackend(int backendId, int)
|
||||
{
|
||||
return backendId == DNN_BACKEND_OPENCV ||
|
||||
backendId == DNN_BACKEND_HALIDE;
|
||||
backendId == DNN_BACKEND_HALIDE || backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH;
|
||||
}
|
||||
|
||||
void apply(const float* srcptr, float* dstptr, int len, size_t planeSize, int cn0, int cn1) const
|
||||
@ -720,7 +721,13 @@ struct MishFunctor
|
||||
#ifdef HAVE_DNN_NGRAPH
|
||||
std::shared_ptr<ngraph::Node> initNgraphAPI(const std::shared_ptr<ngraph::Node>& node)
|
||||
{
|
||||
CV_Error(Error::StsNotImplemented, "");
|
||||
float one = 1.0f;
|
||||
auto constant = std::make_shared<ngraph::op::Constant>(ngraph::element::f32, ngraph::Shape{1}, &one);
|
||||
auto exp_node = std::make_shared<ngraph::op::v0::Exp>(node);
|
||||
auto sum = std::make_shared<ngraph::op::v1::Add>(constant, exp_node, ngraph::op::AutoBroadcastType::NUMPY);
|
||||
auto log_node = std::make_shared<ngraph::op::v0::Log>(sum);
|
||||
auto tanh_node = std::make_shared<ngraph::op::Tanh>(log_node);
|
||||
return std::make_shared<ngraph::op::v1::Multiply>(node, tanh_node);
|
||||
}
|
||||
#endif // HAVE_DNN_NGRAPH
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user