mirror of
https://github.com/opencv/opencv.git
synced 2024-11-24 11:10:21 +08:00
Merge pull request #24985 from Dhanwanth1803:hardswish
Fixes #24974 support HardSwishInt8 #24985 As given very clearly in the issue #24974 I made the required 2 changes to implement HardSwish Layer in INT8. Requesting comments. resolves https://github.com/opencv/opencv/issues/24974 - [X] I agree to contribute to the project under Apache 2 License. - [X] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [X] The PR is proposed to the proper branch - [X] There is a reference to the original bug report and related work - [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [ ] The feature is well documented and sample code can be built with the project CMake Co-authored-by: Dhanwanth1803 <dhanwanthvarala@gmail,com>
This commit is contained in:
parent
bd73b7bcf5
commit
12aa0fe898
@ -212,6 +212,7 @@ void initializeLayerFactory()
|
||||
CV_DNN_REGISTER_LAYER_CLASS(SigmoidInt8, ActivationLayerInt8);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(TanHInt8, ActivationLayerInt8);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(SwishInt8, ActivationLayerInt8);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(HardSwishInt8, ActivationLayerInt8);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(MishInt8, ActivationLayerInt8);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(ELUInt8, ActivationLayerInt8);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(BNLLInt8, ActivationLayerInt8);
|
||||
|
@ -267,6 +267,8 @@ public:
|
||||
res = std::make_shared<ngraph::op::Elu>(input, 1.0f);
|
||||
} else if (type == "MishInt8") {
|
||||
res = std::make_shared<ngraph::op::v4::Mish>(input);
|
||||
} else if (type == "HardSwishInt8") {
|
||||
res = std::make_shared<ngraph::op::v4::HSwish>(input);
|
||||
} else if (type == "AbsValInt8") {
|
||||
res = std::make_shared<ngraph::op::Abs>(input);
|
||||
} else if (type == "SigmoidInt8") {
|
||||
|
@ -939,6 +939,8 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco
|
||||
y = std::min(std::max(x, 0.f), 6.f);
|
||||
else if (opcode == "LOGISTIC")
|
||||
y = 1.0f / (1.0f + std::exp(-x));
|
||||
else if (opcode == "HARD_SWISH")
|
||||
y = x * max(0.f, min(1.f, x / 6.f + 0.5f));
|
||||
else
|
||||
CV_Error(Error::StsNotImplemented, "Lookup table for " + opcode);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user