diff --git a/modules/dnn/src/tflite/tflite_importer.cpp b/modules/dnn/src/tflite/tflite_importer.cpp
index f0e1546306..a23bff2545 100644
--- a/modules/dnn/src/tflite/tflite_importer.cpp
+++ b/modules/dnn/src/tflite/tflite_importer.cpp
@@ -66,6 +66,10 @@ private:
     void parseDequantize(const Operator& op, const std::string& opcode, LayerParams& layerParams);
     void parseDetectionPostProcess(const Operator& op, const std::string& opcode, LayerParams& layerParams);
     void parseActivation(const Operator& op, const std::string& opcode, LayerParams& layerParams);
+    void parseSplit(const Operator& op, const std::string& opcode, LayerParams& layerParams);
+    void parseFullyConnected(const Operator& op, const std::string& opcode, LayerParams& layerParams);
+    void parseSoftmax(const Operator& op, const std::string& opcode, LayerParams& layerParams);
+    void parseCast(const Operator& op, const std::string& opcode, LayerParams& layerParams);
     void parseFusedActivation(const Operator& op, ActivationFunctionType activ);
 
     void parseActivation(const Operator& op, const std::string& opcode, LayerParams& layerParams, bool isFused);
@@ -109,7 +113,7 @@ Mat TFLiteImporter::parseTensor(const Tensor& tensor)
     default:
         CV_Error(Error::StsNotImplemented, format("Parse tensor with type %s", EnumNameTensorType(tensor.type())));
     }
-    return Mat(shape, dtype, const_cast<void*>(data));
+    return shape.empty() ? Mat() : Mat(shape, dtype, const_cast<void*>(data));
 }
 
 TFLiteImporter::TFLiteImporter(Net& dstNet, const char* modelBuffer, size_t bufSize)
@@ -275,6 +279,10 @@ TFLiteImporter::DispatchMap TFLiteImporter::buildDispatchMap()
     dispatch["Convolution2DTransposeBias"] = &TFLiteImporter::parseDeconvolution;
     dispatch["QUANTIZE"] = &TFLiteImporter::parseQuantize;
     dispatch["DEQUANTIZE"] = &TFLiteImporter::parseDequantize;
+    dispatch["SPLIT"] = &TFLiteImporter::parseSplit;
+    dispatch["FULLY_CONNECTED"] = &TFLiteImporter::parseFullyConnected;
+    dispatch["SOFTMAX"] = &TFLiteImporter::parseSoftmax;
+    dispatch["CAST"] = &TFLiteImporter::parseCast;
     dispatch["TFLite_Detection_PostProcess"] = &TFLiteImporter::parseDetectionPostProcess;
     return dispatch;
 }
@@ -809,6 +817,38 @@ void TFLiteImporter::parseDequantize(const Operator& op, const std::string& opco
     addLayer(layerParams, op);
 }
 
+void TFLiteImporter::parseSplit(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
+    layerParams.type = "Slice";
+    auto options = op.builtin_options_as_SplitOptions();
+    CV_Assert(options);
+    layerParams.set("num_split", options->num_splits());
+    addLayer(layerParams, op);
+}
+
+void TFLiteImporter::parseFullyConnected(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
+    layerParams.type = "Gemm";
+    auto options = op.builtin_options_as_FullyConnectedOptions();
+    CV_Assert(options);
+
+    int idx = op.inputs()->Get(1);
+    Mat weights = allTensors[idx];
+    layerParams.blobs.resize(1, weights);
+    layerParams.set("transB", true);
+    layerParams.set("constB", true);
+    addLayer(layerParams, op);
+    parseFusedActivation(op, options->fused_activation_function());
+}
+
+void TFLiteImporter::parseSoftmax(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
+    layerParams.type = "Softmax";
+    addLayer(layerParams, op);
+}
+
+void TFLiteImporter::parseCast(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
+    layerParams.type = "Identity";
+    addLayer(layerParams, op);
+}
+
 void TFLiteImporter::parseDetectionPostProcess(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
     // Parse parameters;
     std::vector<std::string> keys(1, "");
diff --git a/modules/dnn/test/test_tflite_importer.cpp b/modules/dnn/test/test_tflite_importer.cpp
index 29f8bae25e..26f2c373b8 100644
--- a/modules/dnn/test/test_tflite_importer.cpp
+++ b/modules/dnn/test/test_tflite_importer.cpp
@@ -235,6 +235,18 @@ TEST_P(Test_TFLite, replicate_by_pack) {
     testLayer("replicate_by_pack", l1, lInf);
 }
 
+TEST_P(Test_TFLite, split) {
+    testLayer("split");
+}
+
+TEST_P(Test_TFLite, fully_connected) {
+    if (backend == DNN_BACKEND_CUDA)
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA, CV_TEST_TAG_DNN_SKIP_CUDA_FP16);
+    if (backend == DNN_BACKEND_VKCOM)
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_VULKAN);
+    testLayer("fully_connected");
+}
+
 INSTANTIATE_TEST_CASE_P(/**/, Test_TFLite, dnnBackendsAndTargets());
 
 }} // namespace