Merge pull request #25273 from dkurt:tflite_new_layers

TFLite new layers #25273

### Pull Request Readiness Checklist

resolves https://github.com/opencv/opencv/issues/25272, https://github.com/opencv/opencv/issues/24965

**Merge with extra**: https://github.com/opencv/opencv_extra/pull/1160

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
Dmitry Kurtaev 2024-03-29 11:21:13 +03:00 committed by GitHub
parent afb91b552e
commit 01dc010436
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 1 deletions

View File

@ -66,6 +66,10 @@ private:
void parseDequantize(const Operator& op, const std::string& opcode, LayerParams& layerParams); void parseDequantize(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseDetectionPostProcess(const Operator& op, const std::string& opcode, LayerParams& layerParams); void parseDetectionPostProcess(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseActivation(const Operator& op, const std::string& opcode, LayerParams& layerParams); void parseActivation(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseSplit(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseFullyConnected(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseSoftmax(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseCast(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseFusedActivation(const Operator& op, ActivationFunctionType activ); void parseFusedActivation(const Operator& op, ActivationFunctionType activ);
void parseActivation(const Operator& op, const std::string& opcode, LayerParams& layerParams, bool isFused); void parseActivation(const Operator& op, const std::string& opcode, LayerParams& layerParams, bool isFused);
@ -109,7 +113,7 @@ Mat TFLiteImporter::parseTensor(const Tensor& tensor)
default: default:
CV_Error(Error::StsNotImplemented, format("Parse tensor with type %s", EnumNameTensorType(tensor.type()))); CV_Error(Error::StsNotImplemented, format("Parse tensor with type %s", EnumNameTensorType(tensor.type())));
} }
return Mat(shape, dtype, const_cast<void*>(data)); return shape.empty() ? Mat() : Mat(shape, dtype, const_cast<void*>(data));
} }
TFLiteImporter::TFLiteImporter(Net& dstNet, const char* modelBuffer, size_t bufSize) TFLiteImporter::TFLiteImporter(Net& dstNet, const char* modelBuffer, size_t bufSize)
@ -275,6 +279,10 @@ TFLiteImporter::DispatchMap TFLiteImporter::buildDispatchMap()
dispatch["Convolution2DTransposeBias"] = &TFLiteImporter::parseDeconvolution; dispatch["Convolution2DTransposeBias"] = &TFLiteImporter::parseDeconvolution;
dispatch["QUANTIZE"] = &TFLiteImporter::parseQuantize; dispatch["QUANTIZE"] = &TFLiteImporter::parseQuantize;
dispatch["DEQUANTIZE"] = &TFLiteImporter::parseDequantize; dispatch["DEQUANTIZE"] = &TFLiteImporter::parseDequantize;
dispatch["SPLIT"] = &TFLiteImporter::parseSplit;
dispatch["FULLY_CONNECTED"] = &TFLiteImporter::parseFullyConnected;
dispatch["SOFTMAX"] = &TFLiteImporter::parseSoftmax;
dispatch["CAST"] = &TFLiteImporter::parseCast;
dispatch["TFLite_Detection_PostProcess"] = &TFLiteImporter::parseDetectionPostProcess; dispatch["TFLite_Detection_PostProcess"] = &TFLiteImporter::parseDetectionPostProcess;
return dispatch; return dispatch;
} }
@ -809,6 +817,38 @@ void TFLiteImporter::parseDequantize(const Operator& op, const std::string& opco
addLayer(layerParams, op); addLayer(layerParams, op);
} }
void TFLiteImporter::parseSplit(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
layerParams.type = "Slice";
auto options = op.builtin_options_as_SplitOptions();
CV_Assert(options);
layerParams.set("num_split", options->num_splits());
addLayer(layerParams, op);
}
void TFLiteImporter::parseFullyConnected(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
layerParams.type = "Gemm";
auto options = op.builtin_options_as_FullyConnectedOptions();
CV_Assert(options);
int idx = op.inputs()->Get(1);
Mat weights = allTensors[idx];
layerParams.blobs.resize(1, weights);
layerParams.set("transB", true);
layerParams.set("constB", true);
addLayer(layerParams, op);
parseFusedActivation(op, options->fused_activation_function());
}
void TFLiteImporter::parseSoftmax(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
layerParams.type = "Softmax";
addLayer(layerParams, op);
}
void TFLiteImporter::parseCast(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
layerParams.type = "Identity";
addLayer(layerParams, op);
}
void TFLiteImporter::parseDetectionPostProcess(const Operator& op, const std::string& opcode, LayerParams& layerParams) { void TFLiteImporter::parseDetectionPostProcess(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
// Parse parameters; // Parse parameters;
std::vector<std::string> keys(1, ""); std::vector<std::string> keys(1, "");

View File

@ -235,6 +235,18 @@ TEST_P(Test_TFLite, replicate_by_pack) {
testLayer("replicate_by_pack", l1, lInf); testLayer("replicate_by_pack", l1, lInf);
} }
TEST_P(Test_TFLite, split) {
testLayer("split");
}
TEST_P(Test_TFLite, fully_connected) {
if (backend == DNN_BACKEND_CUDA)
applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA, CV_TEST_TAG_DNN_SKIP_CUDA_FP16);
if (backend == DNN_BACKEND_VKCOM)
applyTestTag(CV_TEST_TAG_DNN_SKIP_VULKAN);
testLayer("fully_connected");
}
INSTANTIATE_TEST_CASE_P(/**/, Test_TFLite, dnnBackendsAndTargets()); INSTANTIATE_TEST_CASE_P(/**/, Test_TFLite, dnnBackendsAndTargets());
}} // namespace }} // namespace