mirror of
https://github.com/opencv/opencv.git
synced 2024-11-24 03:00:14 +08:00
Leaky RELU support for TFLite.
This commit is contained in:
parent
79faf857d9
commit
209802c9f6
@ -271,7 +271,7 @@ TFLiteImporter::DispatchMap TFLiteImporter::buildDispatchMap()
|
||||
dispatch["DEPTHWISE_CONV_2D"] = &TFLiteImporter::parseDWConvolution;
|
||||
dispatch["ADD"] = dispatch["MUL"] = &TFLiteImporter::parseEltwise;
|
||||
dispatch["RELU"] = dispatch["PRELU"] = dispatch["HARD_SWISH"] =
|
||||
dispatch["LOGISTIC"] = &TFLiteImporter::parseActivation;
|
||||
dispatch["LOGISTIC"] = dispatch["LEAKY_RELU"] = &TFLiteImporter::parseActivation;
|
||||
dispatch["MAX_POOL_2D"] = dispatch["AVERAGE_POOL_2D"] = &TFLiteImporter::parsePooling;
|
||||
dispatch["MaxPoolingWithArgmax2D"] = &TFLiteImporter::parsePoolingWithArgmax;
|
||||
dispatch["MaxUnpooling2D"] = &TFLiteImporter::parseUnpooling;
|
||||
@ -1029,6 +1029,7 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco
|
||||
}
|
||||
|
||||
void TFLiteImporter::parseActivation(const Operator& op, const std::string& opcode, LayerParams& activParams, bool isFused) {
|
||||
float slope = 0.;
|
||||
if (opcode == "NONE")
|
||||
return;
|
||||
else if (opcode == "RELU6")
|
||||
@ -1041,6 +1042,13 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco
|
||||
activParams.type = "HardSwish";
|
||||
else if (opcode == "LOGISTIC")
|
||||
activParams.type = "Sigmoid";
|
||||
else if (opcode == "LEAKY_RELU")
|
||||
{
|
||||
activParams.type = "ReLU";
|
||||
auto options = reinterpret_cast<const LeakyReluOptions*>(op.builtin_options());
|
||||
slope = options->alpha();
|
||||
activParams.set("negative_slope", slope);
|
||||
}
|
||||
else
|
||||
CV_Error(Error::StsNotImplemented, "Unsupported activation " + opcode);
|
||||
|
||||
@ -1072,6 +1080,8 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco
|
||||
y = 1.0f / (1.0f + std::exp(-x));
|
||||
else if (opcode == "HARD_SWISH")
|
||||
y = x * max(0.f, min(1.f, x / 6.f + 0.5f));
|
||||
else if (opcode == "LEAKY_RELU")
|
||||
y = x >= 0.f ? x : slope*x;
|
||||
else
|
||||
CV_Error(Error::StsNotImplemented, "Lookup table for " + opcode);
|
||||
|
||||
|
@ -268,6 +268,10 @@ TEST_P(Test_TFLite, global_max_pooling_2d) {
|
||||
testLayer("global_max_pooling_2d");
|
||||
}
|
||||
|
||||
TEST_P(Test_TFLite, leakyRelu) {
|
||||
testLayer("leakyRelu");
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(/**/, Test_TFLite, dnnBackendsAndTargets());
|
||||
|
||||
}} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user