mirror of
https://github.com/opencv/opencv.git
synced 2024-12-15 01:39:10 +08:00
76350cd30f
TFLite models importer * initial commit * Refactor TFLiteImporter * Better FlatBuffers detection * Add permute before 4D->3D reshape * Track layers layout * TFLite Convolution2DTransposeBias layer * Skip TFLite tests without FlatBuffers * Fix check of FlatBuffers in tests. Add readNetFromTFLite from buffer * TFLite Max Unpooling test * Add skip for TFLite unpooling test * Revert DW convolution workaround * Fix ObjC bindings * Better errors handling * Regenerate TFLite schema using flatc * dnn(tflite): more checks, better logging * Checks for unimplemented fusion. Fix tests
100 lines
3.7 KiB
C++
100 lines
3.7 KiB
C++
// This file is part of OpenCV project.
|
|
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
|
// of this distribution and at http://opencv.org/license.html.
|
|
|
|
#include "precomp.hpp"
|
|
|
|
|
|
namespace cv {
|
|
namespace dnn {
|
|
CV__DNN_INLINE_NS_BEGIN
|
|
|
|
|
|
Net readNet(const String& _model, const String& _config, const String& _framework)
|
|
{
|
|
String framework = toLowerCase(_framework);
|
|
String model = _model;
|
|
String config = _config;
|
|
const std::string modelExt = model.substr(model.rfind('.') + 1);
|
|
const std::string configExt = config.substr(config.rfind('.') + 1);
|
|
if (framework == "caffe" || modelExt == "caffemodel" || configExt == "caffemodel" || modelExt == "prototxt" || configExt == "prototxt")
|
|
{
|
|
if (modelExt == "prototxt" || configExt == "caffemodel")
|
|
std::swap(model, config);
|
|
return readNetFromCaffe(config, model);
|
|
}
|
|
if (framework == "tensorflow" || modelExt == "pb" || configExt == "pb" || modelExt == "pbtxt" || configExt == "pbtxt")
|
|
{
|
|
if (modelExt == "pbtxt" || configExt == "pb")
|
|
std::swap(model, config);
|
|
return readNetFromTensorflow(model, config);
|
|
}
|
|
if (framework == "tflite" || modelExt == "tflite")
|
|
{
|
|
return readNetFromTFLite(model);
|
|
}
|
|
if (framework == "torch" || modelExt == "t7" || modelExt == "net" || configExt == "t7" || configExt == "net")
|
|
{
|
|
return readNetFromTorch(model.empty() ? config : model);
|
|
}
|
|
if (framework == "darknet" || modelExt == "weights" || configExt == "weights" || modelExt == "cfg" || configExt == "cfg")
|
|
{
|
|
if (modelExt == "cfg" || configExt == "weights")
|
|
std::swap(model, config);
|
|
return readNetFromDarknet(config, model);
|
|
}
|
|
if (framework == "dldt" || modelExt == "bin" || configExt == "bin" || modelExt == "xml" || configExt == "xml")
|
|
{
|
|
if (modelExt == "xml" || configExt == "bin")
|
|
std::swap(model, config);
|
|
return readNetFromModelOptimizer(config, model);
|
|
}
|
|
if (framework == "onnx" || modelExt == "onnx")
|
|
{
|
|
return readNetFromONNX(model);
|
|
}
|
|
CV_Error(Error::StsError, "Cannot determine an origin framework of files: " + model + (config.empty() ? "" : ", " + config));
|
|
}
|
|
|
|
// Reads a network from in-memory buffers, selecting the importer by `_framework`
// (compared case-insensitively): "caffe", "tensorflow", "darknet", "dldt", "tflite" or "onnx".
// `bufferConfig` is forwarded only to the two-buffer importers (Caffe, TensorFlow, Darknet,
// DLDT); the TFLite and ONNX importers read `bufferModel` alone.
// Throws cv::Exception: Error::StsNotImplemented for "torch" (not supported from buffers),
// and Error::StsError for any unrecognized framework name.
Net readNet(const String& _framework, const std::vector<uchar>& bufferModel,
            const std::vector<uchar>& bufferConfig)
{
    String framework = toLowerCase(_framework);
    if (framework == "caffe")
        return readNetFromCaffe(bufferConfig, bufferModel);
    else if (framework == "tensorflow")
        return readNetFromTensorflow(bufferModel, bufferConfig);
    else if (framework == "darknet")
        return readNetFromDarknet(bufferConfig, bufferModel);
    else if (framework == "torch")
        CV_Error(Error::StsNotImplemented, "Reading Torch models from buffers");
    else if (framework == "dldt")
        return readNetFromModelOptimizer(bufferConfig, bufferModel);
    else if (framework == "tflite")
        return readNetFromTFLite(bufferModel);
    else if (framework == "onnx")
        // The file-based readNet() already supports ONNX; keep the buffer overload on par.
        return readNetFromONNX(bufferModel);
    CV_Error(Error::StsError, "Cannot determine an origin framework with a name " + framework);
}
|
|
|
|
// File-based Model Optimizer ("dldt") loader: takes the `.xml` description path and
// the `.bin` weights path and delegates directly to Net::readFromModelOptimizer.
Net readNetFromModelOptimizer(const String& xml, const String& bin)
{
    Net net = Net::readFromModelOptimizer(xml, bin);
    return net;
}
|
|
|
|
// In-memory Model Optimizer ("dldt") loader: `bufferCfg` is the configuration buffer and
// `bufferModel` the weights buffer; both are handed unchanged to Net::readFromModelOptimizer.
Net readNetFromModelOptimizer(const std::vector<uchar>& bufferCfg, const std::vector<uchar>& bufferModel)
{
    Net net = Net::readFromModelOptimizer(bufferCfg, bufferModel);
    return net;
}
|
|
|
|
// Raw-pointer Model Optimizer ("dldt") loader: the configuration is read from
// [bufferModelConfigPtr, bufferModelConfigPtr + bufferModelConfigSize) and the weights from
// [bufferWeightsPtr, bufferWeightsPtr + bufferWeightsSize). This wrapper performs no copying
// or validation of its own; all four arguments go straight to Net::readFromModelOptimizer.
Net readNetFromModelOptimizer(
        const uchar* bufferModelConfigPtr, size_t bufferModelConfigSize,
        const uchar* bufferWeightsPtr, size_t bufferWeightsSize)
{
    Net net = Net::readFromModelOptimizer(bufferModelConfigPtr, bufferModelConfigSize,
                                          bufferWeightsPtr, bufferWeightsSize);
    return net;
}
|
|
|
|
|
|
CV__DNN_INLINE_NS_END
|
|
}} // namespace cv::dnn
|