opencv/modules/dnn/src/dnn_read.cpp
Dmitry Kurtaev 76350cd30f
Merge pull request #23161 from dkurt:dnn_tflite
TFLite models importer

* initial commit

* Refactor TFLiteImporter

* Better FlatBuffers detection

* Add permute before 4D->3D reshape

* Track layers layout

* TFLite Convolution2DTransposeBias layer

* Skip TFLite tests without FlatBuffers

* Fix check of FlatBuffers in tests. Add readNetFromTFLite from buffer

* TFLite Max Unpooling test

* Add skip for TFLite unpooling test

* Revert DW convolution workaround

* Fix ObjC bindings

* Better errors handling

* Regenerate TFLite schema using flatc

* dnn(tflite): more checks, better logging

* Checks for unimplemented fusion. Fix tests
2023-02-13 14:00:20 +00:00

100 lines
3.7 KiB
C++

// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include "precomp.hpp"
namespace cv {
namespace dnn {
CV__DNN_INLINE_NS_BEGIN
/**
 * Reads a network from files, dispatching to a framework-specific importer.
 *
 * The importer is selected as follows:
 *  - if @p _framework (compared case-insensitively) is one of "caffe",
 *    "tensorflow", "tflite", "torch", "darknet", "dldt" or "onnx", that
 *    importer is used, whatever the file extensions are;
 *  - otherwise the framework is guessed from the extensions of @p _model and
 *    @p _config (also compared case-insensitively), probing the frameworks in
 *    the order listed above.
 *
 * For the two-file formats the paths may be given in either order: they are
 * swapped when the extensions show that they were passed the other way round.
 *
 * @param _model     Path to one of the network files.
 * @param _config    Path to the companion file; may be empty.
 * @param _framework Optional framework tag; an unrecognised or empty tag
 *                   falls back to extension-based detection.
 * @return The Net returned by the selected readNetFrom* function.
 *
 * Raises Error::StsError through CV_Error when neither the tag nor the
 * extensions identify a framework.
 */
Net readNet(const String& _model, const String& _config, const String& _framework)
{
    String framework = toLowerCase(_framework);
    String model = _model;
    String config = _config;

    // rfind() returns npos when there is no '.'; npos + 1 wraps around to 0, so
    // a name without a dot is compared as a whole. Extensions are lower-cased
    // so that e.g. "MODEL.ONNX" is recognised just like "model.onnx".
    const std::string modelExt = toLowerCase(model.substr(model.rfind('.') + 1));
    const std::string configExt = toLowerCase(config.substr(config.rfind('.') + 1));

    // An explicitly requested, supported framework must not be overridden by
    // an extension that happens to belong to another framework (for instance
    // framework "onnx" together with a ".bin" file).
    const bool frameworkIsKnown =
        framework == "caffe" || framework == "tensorflow" || framework == "tflite" ||
        framework == "torch" || framework == "darknet" || framework == "dldt" ||
        framework == "onnx";

    if (!frameworkIsKnown)
    {
        // Guess from the extensions; the probing order is significant and is
        // kept identical to the historical one.
        if (modelExt == "caffemodel" || configExt == "caffemodel" || modelExt == "prototxt" || configExt == "prototxt")
            framework = "caffe";
        else if (modelExt == "pb" || configExt == "pb" || modelExt == "pbtxt" || configExt == "pbtxt")
            framework = "tensorflow";
        else if (modelExt == "tflite")
            framework = "tflite";
        else if (modelExt == "t7" || modelExt == "net" || configExt == "t7" || configExt == "net")
            framework = "torch";
        else if (modelExt == "weights" || configExt == "weights" || modelExt == "cfg" || configExt == "cfg")
            framework = "darknet";
        else if (modelExt == "bin" || configExt == "bin" || modelExt == "xml" || configExt == "xml")
            framework = "dldt";
        else if (modelExt == "onnx")
            framework = "onnx";
    }

    if (framework == "caffe")
    {
        // Make sure `config` is the .prototxt and `model` the .caffemodel.
        if (modelExt == "prototxt" || configExt == "caffemodel")
            std::swap(model, config);
        return readNetFromCaffe(config, model);
    }
    if (framework == "tensorflow")
    {
        // Make sure `model` is the .pb and `config` the .pbtxt.
        if (modelExt == "pbtxt" || configExt == "pb")
            std::swap(model, config);
        return readNetFromTensorflow(model, config);
    }
    if (framework == "tflite")
    {
        return readNetFromTFLite(model);
    }
    if (framework == "torch")
    {
        // Single-file format: use whichever of the two paths was supplied.
        return readNetFromTorch(model.empty() ? config : model);
    }
    if (framework == "darknet")
    {
        // Make sure `config` is the .cfg and `model` the .weights.
        if (modelExt == "cfg" || configExt == "weights")
            std::swap(model, config);
        return readNetFromDarknet(config, model);
    }
    if (framework == "dldt")
    {
        // Make sure `config` is the .xml and `model` the .bin.
        if (modelExt == "xml" || configExt == "bin")
            std::swap(model, config);
        return readNetFromModelOptimizer(config, model);
    }
    if (framework == "onnx")
    {
        return readNetFromONNX(model);
    }
    CV_Error(Error::StsError, "Cannot determine an origin framework of files: " + model + (config.empty() ? "" : ", " + config));
}
/**
 * Reads a network from in-memory buffers using the importer named by
 * @p _framework (compared case-insensitively).
 *
 * Supported tags: "caffe", "tensorflow", "darknet", "dldt", "tflite" and
 * "onnx". The "tflite" and "onnx" importers take only @p bufferModel;
 * @p bufferConfig is not used for them.
 *
 * @param _framework   Framework tag; mandatory here because there are no file
 *                     extensions to guess from.
 * @param bufferModel  Buffer forwarded as the model argument of the importer.
 * @param bufferConfig Buffer forwarded as the configuration argument of the
 *                     importer, for the formats that take one.
 * @return The Net returned by the selected readNetFrom* function.
 *
 * Raises Error::StsNotImplemented through CV_Error for "torch" and
 * Error::StsError for any other unrecognised tag.
 */
Net readNet(const String& _framework, const std::vector<uchar>& bufferModel,
            const std::vector<uchar>& bufferConfig)
{
    String framework = toLowerCase(_framework);
    if (framework == "caffe")
        return readNetFromCaffe(bufferConfig, bufferModel);
    else if (framework == "tensorflow")
        return readNetFromTensorflow(bufferModel, bufferConfig);
    else if (framework == "darknet")
        return readNetFromDarknet(bufferConfig, bufferModel);
    else if (framework == "torch")
        CV_Error(Error::StsNotImplemented, "Reading Torch models from buffers");
    else if (framework == "dldt")
        return readNetFromModelOptimizer(bufferConfig, bufferModel);
    else if (framework == "tflite")
        return readNetFromTFLite(bufferModel);
    else if (framework == "onnx")
        // Keeps this overload in line with the file-path overload of
        // readNet(), which accepts "onnx" as well.
        return readNetFromONNX(bufferModel);
    CV_Error(Error::StsError, "Cannot determine an origin framework with a name " + framework);
}
/**
 * Loads a network from a pair of files by delegating to
 * Net::readFromModelOptimizer().
 *
 * @param xml Path forwarded as the first argument (readNet() passes the
 *            ".xml" file here).
 * @param bin Path forwarded as the second argument (readNet() passes the
 *            ".bin" file here).
 * @return The Net produced by Net::readFromModelOptimizer().
 */
Net readNetFromModelOptimizer(const String& xml, const String& bin)
{
    Net loadedNet = Net::readFromModelOptimizer(xml, bin);
    return loadedNet;
}
/**
 * Loads a network from two in-memory buffers by delegating to the matching
 * Net::readFromModelOptimizer() overload.
 *
 * @param bufferCfg   Buffer forwarded as the configuration argument.
 * @param bufferModel Buffer forwarded as the model argument.
 * @return The Net produced by Net::readFromModelOptimizer().
 */
Net readNetFromModelOptimizer(const std::vector<uchar>& bufferCfg, const std::vector<uchar>& bufferModel)
{
    Net loadedNet = Net::readFromModelOptimizer(bufferCfg, bufferModel);
    return loadedNet;
}
/**
 * Loads a network from two raw memory regions by delegating to the matching
 * Net::readFromModelOptimizer() overload.
 *
 * @param bufferModelConfigPtr  Start of the model-configuration data.
 * @param bufferModelConfigSize Size passed along with bufferModelConfigPtr.
 * @param bufferWeightsPtr      Start of the weights data.
 * @param bufferWeightsSize     Size passed along with bufferWeightsPtr.
 * @return The Net produced by Net::readFromModelOptimizer().
 */
Net readNetFromModelOptimizer(
        const uchar* bufferModelConfigPtr, size_t bufferModelConfigSize,
        const uchar* bufferWeightsPtr, size_t bufferWeightsSize)
{
    Net loadedNet = Net::readFromModelOptimizer(bufferModelConfigPtr, bufferModelConfigSize,
                                                bufferWeightsPtr, bufferWeightsSize);
    return loadedNet;
}
CV__DNN_INLINE_NS_END
}} // namespace cv::dnn