Merge pull request #24985 from Dhanwanth1803:hardswish

Fixes #24974 support HardSwishInt8 #24985

As given very clearly in the issue #24974 I made the required 2 changes to implement HardSwish Layer in INT8. Requesting comments.

resolves https://github.com/opencv/opencv/issues/24974

- [X] I agree to contribute to the project under Apache 2 License.
- [X] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [X] The PR is proposed to the proper branch
- [X] There is a reference to the original bug report and related work
- [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [ ] The feature is well documented and sample code can be built with the project CMake

Co-authored-by: Dhanwanth1803 <dhanwanthvarala@gmail,com>
This commit is contained in:
Dhanwanth1803 2024-02-16 20:49:29 +05:30 committed by GitHub
parent bd73b7bcf5
commit 12aa0fe898
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 5 additions and 0 deletions

View File

@ -212,6 +212,7 @@ void initializeLayerFactory()
CV_DNN_REGISTER_LAYER_CLASS(SigmoidInt8, ActivationLayerInt8);
CV_DNN_REGISTER_LAYER_CLASS(TanHInt8, ActivationLayerInt8);
CV_DNN_REGISTER_LAYER_CLASS(SwishInt8, ActivationLayerInt8);
CV_DNN_REGISTER_LAYER_CLASS(HardSwishInt8, ActivationLayerInt8);
CV_DNN_REGISTER_LAYER_CLASS(MishInt8, ActivationLayerInt8);
CV_DNN_REGISTER_LAYER_CLASS(ELUInt8, ActivationLayerInt8);
CV_DNN_REGISTER_LAYER_CLASS(BNLLInt8, ActivationLayerInt8);

View File

@ -267,6 +267,8 @@ public:
res = std::make_shared<ngraph::op::Elu>(input, 1.0f);
} else if (type == "MishInt8") {
res = std::make_shared<ngraph::op::v4::Mish>(input);
} else if (type == "HardSwishInt8") {
res = std::make_shared<ngraph::op::v4::HSwish>(input);
} else if (type == "AbsValInt8") {
res = std::make_shared<ngraph::op::Abs>(input);
} else if (type == "SigmoidInt8") {

View File

@ -939,6 +939,8 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco
y = std::min(std::max(x, 0.f), 6.f);
else if (opcode == "LOGISTIC")
y = 1.0f / (1.0f + std::exp(-x));
else if (opcode == "HARD_SWISH")
y = x * max(0.f, min(1.f, x / 6.f + 0.5f));
else
CV_Error(Error::StsNotImplemented, "Lookup table for " + opcode);