Merge pull request #23161 from dkurt:dnn_tflite

TFLite models importer

* initial commit

* Refactor TFLiteImporter

* Better FlatBuffers detection

* Add permute before 4D->3D reshape

* Track layers layout

* TFLite Convolution2DTransposeBias layer

* Skip TFLite tests without FlatBuffers

* Fix check of FlatBuffers in tests. Add readNetFromTFLite from buffer

* TFLite Max Unpooling test

* Add skip for TFLite unpooling test

* Revert DW convolution workaround

* Fix ObjC bindings

* Better errors handling

* Regenerate TFLite schema using flatc

* dnn(tflite): more checks, better logging

* Checks for unimplemented fusion. Fix tests
Dmitry Kurtaev authored on 2023-02-13 17:00:20 +03:00, committed by GitHub
parent 47293f28cf
commit 76350cd30f
11 changed files with 2215 additions and 2 deletions


@@ -471,6 +471,9 @@ OCV_OPTION(WITH_OBSENSOR "Include obsensor support (Orbbec RGB-D modules: Astra+
OCV_OPTION(WITH_CANN "Include CANN support" OFF
VISIBLE_IF TRUE
VERIFY HAVE_CANN)
OCV_OPTION(WITH_FLATBUFFERS "Include FlatBuffers support" OFF
VISIBLE_IF TRUE
VERIFY HAVE_FLATBUFFERS)
# OpenCV build components
# ===================================================
@@ -750,6 +753,7 @@ include(cmake/OpenCVFindLibsVideo.cmake)
include(cmake/OpenCVFindLibsPerf.cmake)
include(cmake/OpenCVFindLAPACK.cmake)
include(cmake/OpenCVFindProtobuf.cmake)
include(cmake/OpenCVFindFlatBuffers.cmake)
if(WITH_TENGINE)
include(cmake/OpenCVFindTengine.cmake)
endif()


@@ -0,0 +1,15 @@
set(HAVE_FLATBUFFERS FALSE)
if(NOT WITH_FLATBUFFERS)
return()
endif()
list(APPEND CUSTOM_STATUS flatbuffers)
find_package(flatbuffers QUIET)
if(flatbuffers_FOUND)
set(HAVE_FLATBUFFERS 1)
list(APPEND CUSTOM_STATUS_flatbuffers " FlatBuffers:" "${flatbuffers_VERSION}")
else()
list(APPEND CUSTOM_STATUS_flatbuffers " FlatBuffers:" "NO")
endif()


@@ -133,6 +133,17 @@ if(NOT BUILD_PROTOBUF)
list(APPEND include_dirs ${Protobuf_INCLUDE_DIRS})
endif()
if(HAVE_FLATBUFFERS)
list(APPEND libs flatbuffers::flatbuffers)
list(APPEND fw_srcs "${CMAKE_CURRENT_BINARY_DIR}/schema_generated.h")
add_custom_command(
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/schema_generated.h"
COMMAND flatbuffers::flatc --cpp -o "${CMAKE_CURRENT_BINARY_DIR}" "${CMAKE_CURRENT_LIST_DIR}/src/tflite/schema.fbs")
ocv_target_compile_definitions(${the_module} PRIVATE "HAVE_FLATBUFFERS=1")
endif()
set(sources_options "")
list(APPEND libs ${LAPACK_LIBRARIES})
@@ -280,3 +291,9 @@ if(TARGET ocv.3rdparty.cann AND OPENCV_TEST_DNN_CANN)
ocv_target_link_libraries(opencv_test_dnn ocv.3rdparty.cann)
endif()
endif()
if(HAVE_FLATBUFFERS)
if(TARGET opencv_test_dnn)
ocv_target_compile_definitions(opencv_test_dnn PRIVATE "HAVE_FLATBUFFERS=1")
endif()
endif()


@@ -953,6 +953,26 @@ CV__DNN_INLINE_NS_BEGIN
CV_EXPORTS Net readNetFromTensorflow(const char *bufferModel, size_t lenModel,
const char *bufferConfig = NULL, size_t lenConfig = 0);
/** @brief Reads a network model stored in <a href="https://www.tensorflow.org/lite">TFLite</a> framework's format.
* @param model path to the .tflite file with a binary FlatBuffers description of the network architecture
* @returns Net object.
*/
CV_EXPORTS_W Net readNetFromTFLite(const String &model);
/** @brief Reads a network model stored in <a href="https://www.tensorflow.org/lite">TFLite</a> framework's format.
* @param bufferModel buffer containing the content of the tflite file
* @returns Net object.
*/
CV_EXPORTS_W Net readNetFromTFLite(const std::vector<uchar>& bufferModel);
/** @brief Reads a network model stored in <a href="https://www.tensorflow.org/lite">TFLite</a> framework's format.
* @details This is an overloaded member function, provided for convenience.
* It differs from the above function only in what argument(s) it accepts.
* @param bufferModel buffer containing the content of the tflite file
* @param lenModel length of bufferModel
*/
CV_EXPORTS Net readNetFromTFLite(const char *bufferModel, size_t lenModel);
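For context, a minimal usage sketch of these readers (hypothetical file names; assumes an OpenCV build configured with WITH_FLATBUFFERS=ON):

#include <opencv2/dnn.hpp>
#include <opencv2/imgcodecs.hpp>

int main() {
    // Hypothetical model and image paths.
    cv::dnn::Net net = cv::dnn::readNetFromTFLite("face_landmark.tflite");
    cv::Mat img = cv::imread("face.jpg");
    // TFLite models usually expect RGB input, hence swapRB=true.
    cv::Mat blob = cv::dnn::blobFromImage(img, 1.0 / 255, cv::Size(192, 192),
                                          cv::Scalar(), /*swapRB=*/true);
    net.setInput(blob);
    cv::Mat out = net.forward();
    return 0;
}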
/**
* @brief Reads a network model stored in <a href="http://torch.ch">Torch7</a> framework's format.
* @param model path to the file, dumped from Torch by using torch.save() function.


@@ -8,7 +8,9 @@
"(Net*)readNetFromONNX:(NSString*)onnxFile" : { "readNetFromONNX" : {"name" : "readNetFromONNXFile"} },
"(Net*)readNetFromONNX:(ByteVector*)buffer" : { "readNetFromONNX" : {"name" : "readNetFromONNXBuffer"} },
"(Net*)readNetFromTensorflow:(NSString*)model config:(NSString*)config" : { "readNetFromTensorflow" : {"name" : "readNetFromTensorflowFile"} },
"(Net*)readNetFromTensorflow:(ByteVector*)bufferModel bufferConfig:(ByteVector*)bufferConfig" : { "readNetFromTensorflow" : {"name" : "readNetFromTensorflowBuffer"} }
"(Net*)readNetFromTensorflow:(ByteVector*)bufferModel bufferConfig:(ByteVector*)bufferConfig" : { "readNetFromTensorflow" : {"name" : "readNetFromTensorflowBuffer"} },
"(Net*)readNetFromTFLite:(NSString*)model" : { "readNetFromTFLite" : {"name" : "readNetFromTFLiteFile"} },
"(Net*)readNetFromTFLite:(ByteVector*)buffer" : { "readNetFromTFLite" : {"name" : "readNetFromTFLiteBuffer"} }
},
"Net": {
"(void)forward:(NSMutableArray<Mat*>*)outputBlobs outputName:(NSString*)outputName" : { "forward" : {"name" : "forwardOutputBlobs"} },


@@ -29,6 +29,10 @@ Net readNet(const String& _model, const String& _config, const String& _framewor
std::swap(model, config);
return readNetFromTensorflow(model, config);
}
if (framework == "tflite" || modelExt == "tflite")
{
return readNetFromTFLite(model);
}
if (framework == "torch" || modelExt == "t7" || modelExt == "net" || configExt == "t7" || configExt == "net")
{
return readNetFromTorch(model.empty() ? config : model);
@@ -66,6 +70,8 @@ Net readNet(const String& _framework, const std::vector<uchar>& bufferModel,
CV_Error(Error::StsNotImplemented, "Reading Torch models from buffers");
else if (framework == "dldt")
return readNetFromModelOptimizer(bufferConfig, bufferModel);
else if (framework == "tflite")
return readNetFromTFLite(bufferModel);
CV_Error(Error::StsError, "Cannot determine an origin framework with a name " + framework);
}


@@ -0,0 +1,41 @@
// source: https://github.com/tensorflow/tensorflow/blob/b2f5959ff823a8ed5bf4883e785f8f96d4253a8b/tensorflow/lite/core/c/builtin_op_data.h
typedef enum {
kTfLitePaddingUnknown = 0,
kTfLitePaddingSame,
kTfLitePaddingValid,
} TfLitePadding;
typedef enum {
kTfLiteActNone = 0,
kTfLiteActRelu,
kTfLiteActReluN1To1, // min(max(-1, x), 1)
kTfLiteActRelu6, // min(max(0, x), 6)
kTfLiteActTanh,
kTfLiteActSignBit,
kTfLiteActSigmoid,
} TfLiteFusedActivation;
typedef struct {
int width;
int height;
int width_offset;
int height_offset;
} TfLitePaddingValues;
typedef struct {
TfLitePadding padding;
int stride_width;
int stride_height;
int filter_width;
int filter_height;
TfLiteFusedActivation activation;
struct {
TfLitePaddingValues padding;
} computed;
} TfLitePoolParams;
typedef struct {
TfLitePadding padding;
int stride_width;
int stride_height;
} TfLiteTransposeConvParams;

File diff suppressed because it is too large


@@ -0,0 +1,644 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include "../precomp.hpp"
#ifdef HAVE_FLATBUFFERS
#include "schema_generated.h"
#include "builtin_op_data.h"
#endif
#include <opencv2/core/utils/logger.defines.hpp>
#undef CV_LOG_STRIP_LEVEL
#define CV_LOG_STRIP_LEVEL CV_LOG_LEVEL_VERBOSE + 1
#include <opencv2/core/utils/logger.hpp>
namespace cv {
namespace dnn {
CV__DNN_INLINE_NS_BEGIN
#ifdef HAVE_FLATBUFFERS
using namespace opencv_tflite;
// These values are used to indicate the data layout of layer outputs where possible.
// The approach is similar to the TensorFlow importer, but TFLite models have no explicit
// layout field "data_format", so we assume that all 4D inputs are in NHWC data layout.
enum DataLayout
{
DATA_LAYOUT_NHWC,
DATA_LAYOUT_NCHW,
DATA_LAYOUT_NDHWC,
DATA_LAYOUT_UNKNOWN,
DATA_LAYOUT_PLANAR // 2-dimensional outputs (matmul, flatten, reshape to 2d)
};
class TFLiteImporter {
public:
TFLiteImporter(Net& net, const char* modelBuffer, size_t bufSize);
private:
const opencv_tflite::Model* model;
const flatbuffers::Vector<flatbuffers::Offset<opencv_tflite::Tensor> >* modelTensors;
std::map<int, Mat> allTensors;
Net& dstNet;
// Maps tensor indices from TFLite notation to pairs (layerId, outputId)
// of the created OpenCV layers.
std::map<int, std::pair<int, int> > layerIds;
// Tracks the layouts of layer outputs.
std::vector<DataLayout> layouts;
void populateNet();
// Wrap a TFLite Tensor as an OpenCV Mat without copying the data
Mat parseTensor(const Tensor& tensor);
typedef void (TFLiteImporter::*TFLiteImporterNodeParser)(const Operator&, const std::string&, LayerParams&);
typedef std::map<std::string, TFLiteImporterNodeParser> DispatchMap;
const DispatchMap dispatch;
static DispatchMap buildDispatchMap();
void parseConvolution(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseDWConvolution(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parsePadding(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseEltwise(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parsePooling(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parsePoolingWithArgmax(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseUnpooling(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseReshape(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseConcat(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseResize(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseDeconvolution(const Operator& op, const std::string& opcode, LayerParams& layerParams);
int addPermuteLayer(const std::vector<int>& order, const std::string& permName, const std::pair<int, int>& inpId);
};
Mat TFLiteImporter::parseTensor(const Tensor& tensor)
{
const auto tensor_shape = tensor.shape();
CV_Assert(tensor_shape);
std::vector<int> shape(tensor_shape->begin(), tensor_shape->end());
int bufferIdx = tensor.buffer();
CV_Assert(bufferIdx != 0); // 0th buffer is a no-data buffer
const Buffer* buffer = model->buffers()->Get(bufferIdx);
CV_Assert(buffer);
const auto buffer_data = buffer->data();
CV_Assert(buffer_data);
const void* data = buffer_data->data();
int dtype = -1;
switch (tensor.type()) {
case TensorType_FLOAT32:
dtype = CV_32F;
break;
case TensorType_INT32:
dtype = CV_32S;
break;
case TensorType_FLOAT16:
dtype = CV_16S;
break;
default:
CV_Error(Error::StsNotImplemented, format("Parse tensor with type %s", EnumNameTensorType(tensor.type())));
}
return Mat(shape, dtype, const_cast<void*>(data));
}
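parseTensor returns a Mat that merely wraps the FlatBuffers storage. A small sketch of the ownership rule that populateNet relies on below (a Mat constructed over external memory has u == nullptr, so such tensors are cloned before being kept as layer blobs):

#include <opencv2/core.hpp>

int main() {
    float external[4] = {1.f, 2.f, 3.f, 4.f};
    cv::Mat wrapped(2, 2, CV_32F, external);  // non-owning view: wrapped.u == nullptr
    cv::Mat owned = wrapped.clone();          // deep copy with refcounted storage
    CV_Assert(wrapped.u == nullptr && owned.u != nullptr);
    return 0;
}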
TFLiteImporter::TFLiteImporter(Net& dstNet, const char* modelBuffer, size_t bufSize)
: dstNet(dstNet), dispatch(buildDispatchMap())
{
flatbuffers::Verifier verifier((const uint8_t*)modelBuffer, bufSize);
if (!VerifyModelBuffer(verifier)) {
CV_Error(Error::StsError, "DNN/TFLite: model is incorrect");
}
model = GetModel(modelBuffer);
CV_Assert(model);
CV_Assert(model->subgraphs());
CV_Assert(model->buffers());
CV_CheckEQ(model->subgraphs()->size(), 1, "");
modelTensors = model->subgraphs()->Get(0)->tensors();
CV_Assert(modelTensors);
for (int i = 0; i < modelTensors->size(); ++i) {
const Tensor* tensor = modelTensors->Get(i);
CV_Assert(tensor);
if (tensor->buffer() != 0) {
allTensors[i] = parseTensor(*tensor);
}
}
populateNet();
}
DataLayout estimateLayout(const Tensor& t)
{
const auto t_shape = t.shape();
CV_Assert(t_shape);
switch (t_shape->size()) {
case 5: return DATA_LAYOUT_NDHWC;
case 4: return DATA_LAYOUT_NHWC;
case 2: return DATA_LAYOUT_PLANAR;
default: return DATA_LAYOUT_UNKNOWN;
}
}
void TFLiteImporter::populateNet()
{
CV_Assert(model);
const auto model_subgraphs = model->subgraphs();
CV_Assert(model_subgraphs);
const SubGraph* subgraph = model_subgraphs->Get(0);
CV_Assert(subgraph);
const auto subgraph_inputs = subgraph->inputs();
CV_Assert(subgraph_inputs);
const auto subgraph_operators = subgraph->operators();
CV_Assert(subgraph_operators);
const auto opCodes = model->operator_codes();
CV_Assert(opCodes);
CV_Assert(modelTensors);
layouts.resize(modelTensors->size(), DATA_LAYOUT_UNKNOWN);
size_t subgraph_inputs_size = subgraph_inputs->size();
for (size_t i = 0; i < subgraph_inputs_size; ++i)
{
int idx = subgraph_inputs->Get(i);
layerIds[idx] = std::make_pair(0, i);
const auto tensor = modelTensors->Get(idx);
if (!tensor)
CV_Error(Error::StsError, cv::format("DNN/TFLite: subgraph input %d (%d) is NULL", (int)i, idx));
layouts[idx] = estimateLayout(*tensor);
}
const auto& all_operators = *subgraph_operators;
const size_t all_operators_size = all_operators.size();
for (size_t op_idx = 0; op_idx < all_operators_size; ++op_idx)
{
const auto op = all_operators[op_idx];
CV_Assert(op);
const auto op_inputs = op->inputs();
CV_Assert(op_inputs);
const auto op_outputs = op->outputs();
CV_Assert(op_outputs);
int idx = op->opcode_index();
LayerParams layerParams;
layerParams.name = modelTensors->Get(op_outputs->Get(0))->name()->str();
std::string type = EnumNameBuiltinOperator(BuiltinOperator(opCodes->Get(idx)->deprecated_builtin_code()));
if (type == "CUSTOM") {
type = opCodes->Get(idx)->custom_code()->str();
}
CV_LOG_DEBUG(NULL, "DNN/TFLite: processing operator (" << op_idx << "/" << all_operators_size << ") with " << op_inputs->size() << " inputs: "
<< cv::format("[%s]:(%s)", type.c_str(), layerParams.name.c_str()));
try
{
if (type == "DEQUANTIZE") {
// Convert from FP16 to FP32
Mat data = allTensors[op_inputs->Get(0)];
Mat dataFP32;
convertFp16(data, dataFP32);
allTensors[op_outputs->Get(0)] = dataFP32;
continue;
}
DispatchMap::const_iterator iter = dispatch.find(type);
if (iter == dispatch.end())
CV_Error(Error::StsNotImplemented, "Unsupported operator type " + type);
CALL_MEMBER_FN(*this, iter->second)(*op, type, layerParams);
// Collect input blobs
std::vector<int> layerInputs;
std::vector<DataLayout> inpLayouts;
for (int idx : *op_inputs) {
if (layerIds.find(idx) != layerIds.end()) {
layerInputs.push_back(idx);
inpLayouts.push_back(layouts[idx]);
continue; // Output from a different layer
}
Mat blob = allTensors[idx];
layerParams.blobs.push_back(blob.u ? blob : blob.clone()); // clone tensors that merely wrap the FlatBuffers storage; the rest are already owned by OpenCV
}
int layerId = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
// Connect layer to inputs
int i = 0;
for (int idx : layerInputs) {
auto it = layerIds.find(idx);
CV_Assert(it != layerIds.end());
dstNet.connect(it->second.first, it->second.second, layerId, i++);
}
// Predict the output layout. Some layer-specific parsers may set it explicitly;
// otherwise, propagate the input layout.
if (layouts[op_outputs->Get(0)] == DATA_LAYOUT_UNKNOWN) {
DataLayout predictedLayout = DATA_LAYOUT_UNKNOWN;
for (auto layout : inpLayouts) {
if (layout != DATA_LAYOUT_UNKNOWN) {
if (predictedLayout == DATA_LAYOUT_UNKNOWN)
predictedLayout = layout;
else if (predictedLayout != layout) {
predictedLayout = DATA_LAYOUT_UNKNOWN;
break;
}
}
}
layouts[op_outputs->Get(0)] = predictedLayout;
}
// Register outputs
i = 0;
for (int idx : *op_outputs) {
layerIds[idx] = std::make_pair(layerId, i++);
}
}
catch (const cv::Exception& e)
{
CV_LOG_ERROR(NULL, "DNN/TFLite: Problem during import of operator "
<< cv::format("[%s]:(%s)", type.c_str(), layerParams.name.c_str())
<< " (" << op_idx << "/" << all_operators_size << "). Exception: " << e.what());
if (DNN_DIAGNOSTICS_RUN)
{
continue;
}
throw;
}
}
}
TFLiteImporter::DispatchMap TFLiteImporter::buildDispatchMap()
{
static DispatchMap dispatch;
if (!dispatch.empty())
return dispatch;
dispatch["CONV_2D"] = &TFLiteImporter::parseConvolution;
dispatch["DEPTHWISE_CONV_2D"] = &TFLiteImporter::parseDWConvolution;
dispatch["RELU"] = dispatch["ADD"] = dispatch["MUL"] = dispatch["PRELU"] =
dispatch["HARD_SWISH"] = dispatch["LOGISTIC"] = &TFLiteImporter::parseEltwise;
dispatch["MAX_POOL_2D"] = dispatch["AVERAGE_POOL_2D"] = &TFLiteImporter::parsePooling;
dispatch["MaxPoolingWithArgmax2D"] = &TFLiteImporter::parsePoolingWithArgmax;
dispatch["MaxUnpooling2D"] = &TFLiteImporter::parseUnpooling;
dispatch["PAD"] = &TFLiteImporter::parsePadding;
dispatch["RESHAPE"] = &TFLiteImporter::parseReshape;
dispatch["CONCATENATION"] = &TFLiteImporter::parseConcat;
dispatch["RESIZE_BILINEAR"] = &TFLiteImporter::parseResize;
dispatch["Convolution2DTransposeBias"] = &TFLiteImporter::parseDeconvolution;
return dispatch;
}
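The dispatch map stores pointers to member functions; CALL_MEMBER_FN, defined elsewhere in the dnn module (presumably as ((object).*(ptrToMember))), performs the invocation. The pattern in isolation, as a sketch:

#include <iostream>
#include <map>
#include <string>

struct Importer {
    typedef void (Importer::*NodeParser)(const std::string&);
    std::map<std::string, NodeParser> dispatch;

    Importer() { dispatch["CONV_2D"] = &Importer::parseConv; }
    void parseConv(const std::string& name) { std::cout << "conv: " << name << "\n"; }

    void run(const std::string& type, const std::string& name) {
        NodeParser p = dispatch.at(type);
        ((*this).*p)(name);  // what CALL_MEMBER_FN(*this, p)(name) expands to
    }
};

int main() {
    Importer().run("CONV_2D", "conv1");
    return 0;
}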
void TFLiteImporter::parseConvolution(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
layerParams.type = "Convolution";
auto options = reinterpret_cast<const Conv2DOptions*>(op.builtin_options());
if (options->fused_activation_function() != ActivationFunctionType_NONE) {
CV_Error(Error::StsNotImplemented, "Convolution with fused activation");
}
layerParams.set("pad_mode", EnumNamePadding(options->padding()));
layerParams.set("stride_w", options->stride_w());
layerParams.set("stride_h", options->stride_h());
layerParams.set("dilation_w", options->dilation_w_factor());
layerParams.set("dilation_h", options->dilation_h_factor());
// Get filter size
int filterIdx = op.inputs()->Get(1);
Mat filter = allTensors[filterIdx];
int oc = filter.size[0];
int kh = filter.size[1];
int kw = filter.size[2];
int ic = filter.size[3];
layerParams.set("kernel_w", kw);
layerParams.set("kernel_h", kh);
layerParams.set("num_output", oc);
// Reorder filter data from OHWI to OIHW and change the shape accordingly.
filter = allTensors[filterIdx] = filter.reshape(1, {oc, ic, kh, kw});
CV_CheckTypeEQ(filter.type(), CV_32F, "");
Mat filterCopy = filter.clone();
float* data = filterCopy.ptr<float>();
float* dstData = filter.ptr<float>();
int total = oc * ic * kh * kw;
for (int i_oc = 0; i_oc < oc; i_oc++) {
for (int i_ic = 0; i_ic < ic; i_ic++) {
for (int i_h = 0; i_h < kh; i_h++) {
for (int i_w = 0; i_w < kw; i_w++) {
int dst_i = kw * (kh * (ic * i_oc + i_ic) + i_h) + i_w;
int src_i = ic * (kw * (kh * i_oc + i_h) + i_w) + i_ic;
CV_CheckLT(dst_i, total, "");
CV_CheckLT(src_i, total, "");
dstData[dst_i] = data[src_i];
}
}
}
}
}
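The dst_i/src_i expressions above are simply the row-major offsets of element (i_oc, i_ic, i_h, i_w) in the OIHW and OHWI layouts; a standalone check of that arithmetic:

#include <cassert>

int main() {
    const int oc = 2, ic = 4, kh = 3, kw = 3;
    for (int o = 0; o < oc; o++)
        for (int i = 0; i < ic; i++)
            for (int h = 0; h < kh; h++)
                for (int w = 0; w < kw; w++) {
                    int dst_i = kw * (kh * (ic * o + i) + h) + w;  // OIHW offset of (o, i, h, w)
                    int src_i = ic * (kw * (kh * o + h) + w) + i;  // OHWI offset of (o, h, w, i)
                    assert(dst_i == ((o * ic + i) * kh + h) * kw + w);
                    assert(src_i == ((o * kh + h) * kw + w) * ic + i);
                }
    return 0;
}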
void TFLiteImporter::parseDWConvolution(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
layerParams.type = "Convolution";
auto options = reinterpret_cast<const DepthwiseConv2DOptions*>(op.builtin_options());
if (options->fused_activation_function() != ActivationFunctionType_NONE) {
CV_Error(Error::StsNotImplemented, "Depthwise convolution with fused activation");
}
layerParams.set("pad_mode", EnumNamePadding(options->padding()));
layerParams.set("stride_w", options->stride_w());
layerParams.set("stride_h", options->stride_h());
layerParams.set("dilation_w", options->dilation_w_factor());
layerParams.set("dilation_h", options->dilation_h_factor());
int filterIdx = op.inputs()->Get(1);
Mat filter = allTensors[filterIdx];
int kh = filter.size[1];
int kw = filter.size[2];
int oc = filter.size[3];
layerParams.set("kernel_w", kw);
layerParams.set("kernel_h", kh);
layerParams.set("num_output", oc);
layerParams.set("group", oc);
filter = allTensors[filterIdx] = filter.reshape(1, {oc, 1, kh, kw});
cv::transpose(filter.reshape(1, kh * kw).clone(), filter.reshape(1, oc));
}
void TFLiteImporter::parsePadding(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
layerParams.type = "Padding";
Mat paddings = allTensors[op.inputs()->Get(1)];
CV_CheckTypeEQ(paddings.type(), CV_32S, "");
// N H W C
// 0 1 2 3 4 5 6 7
std::swap(paddings.at<int32_t>(2), paddings.at<int32_t>(6));
std::swap(paddings.at<int32_t>(3), paddings.at<int32_t>(7));
// N C W H
// 0 1 2 3 4 5 6 7
std::swap(paddings.at<int32_t>(4), paddings.at<int32_t>(6));
std::swap(paddings.at<int32_t>(5), paddings.at<int32_t>(7));
// N C H W
// 0 1 2 3 4 5 6 7
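// e.g. starting from TFLite's NHWC paddings {n0,n1, h0,h1, w0,w1, c0,c1}:
//   after swap(2,6), swap(3,7): {n0,n1, c0,c1, w0,w1, h0,h1}  (N C W H)
//   after swap(4,6), swap(5,7): {n0,n1, c0,c1, h0,h1, w0,w1}  (N C H W)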
layerParams.set("paddings", DictValue::arrayInt<int32_t*>((int32_t*)paddings.data, paddings.total()));
}
void TFLiteImporter::parseEltwise(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
if (opcode == "PRELU") {
layerParams.type = "PReLU";
} else if (opcode == "RELU") {
layerParams.type = "ReLU";
} else if (opcode == "ADD") {
auto options = reinterpret_cast<const AddOptions*>(op.builtin_options());
if (options->fused_activation_function() != ActivationFunctionType_NONE) {
CV_Error(Error::StsNotImplemented, "Add with fused activation");
}
layerParams.type = "Eltwise";
layerParams.set("operation", "sum");
} else if (opcode == "MUL") {
auto options = reinterpret_cast<const MulOptions*>(op.builtin_options());
if (options->fused_activation_function() != ActivationFunctionType_NONE) {
CV_Error(Error::StsNotImplemented, "Mul with fused activation");
}
layerParams.type = "Eltwise";
layerParams.set("operation", "prod");
} else if (opcode == "HARD_SWISH") {
layerParams.type = "HardSwish";
} else if (opcode == "LOGISTIC") {
layerParams.type = "Sigmoid";
} else {
CV_Error(Error::StsNotImplemented, "Unknown eltwise operator opcode: " + opcode);
}
}
void TFLiteImporter::parsePooling(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
layerParams.type = "Pooling";
auto options = reinterpret_cast<const Pool2DOptions*>(op.builtin_options());
if (options->fused_activation_function() != ActivationFunctionType_NONE) {
CV_Error(Error::StsNotImplemented, "Pooling with fused activation");
}
layerParams.set("pad_mode", EnumNamePadding(options->padding()));
layerParams.set("stride_w", options->stride_w());
layerParams.set("stride_h", options->stride_h());
layerParams.set("kernel_w", options->filter_width());
layerParams.set("kernel_h", options->filter_height());
if (opcode == "MAX_POOL_2D")
layerParams.set("pool", "max");
else if (opcode == "AVERAGE_POOL_2D")
layerParams.set("pool", "ave");
else
CV_Error(Error::StsNotImplemented, "Pool type selection for " + opcode);
}
void TFLiteImporter::parsePoolingWithArgmax(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
layerParams.type = "Pooling";
CV_CheckLE(op.custom_options()->size(), sizeof(TfLitePoolParams), "");
const auto* params = reinterpret_cast<const TfLitePoolParams*>(op.custom_options()->Data());
if (params->activation != kTfLiteActNone) {
CV_Error(Error::StsNotImplemented, "Argmax pooling with fused activation");
}
if (params->padding != kTfLitePaddingUnknown)
layerParams.set("pad_mode", params->padding == kTfLitePaddingSame ? "SAME" : "VALID");
layerParams.set("stride_w", params->stride_width);
layerParams.set("stride_h", params->stride_height);
layerParams.set("kernel_w", params->filter_width);
layerParams.set("kernel_h", params->filter_height);
layerParams.set("pool", "max");
}
void TFLiteImporter::parseUnpooling(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
layerParams.type = "MaxUnpool";
CV_CheckLE(op.custom_options()->size(), sizeof(TfLitePoolParams), "");
const auto* params = reinterpret_cast<const TfLitePoolParams*>(op.custom_options()->Data());
if (params->activation != kTfLiteActNone) {
CV_Error(Error::StsNotImplemented, "Unpooling with fused activation");
}
layerParams.set("pool_stride_w", params->stride_width);
layerParams.set("pool_stride_h", params->stride_height);
layerParams.set("pool_k_w", params->filter_width);
layerParams.set("pool_k_h", params->filter_height);
layerParams.set("pool_pad_w", 0);
layerParams.set("pool_pad_h", 0);
}
void TFLiteImporter::parseReshape(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
DataLayout inpLayout = layouts[op.inputs()->Get(0)];
if (inpLayout == DATA_LAYOUT_NHWC) {
// OpenCV stores the blob in NCHW order; permute it to NHWC so that TFLite's
// new_shape (given in NHWC terms) applies directly.
int permId = addPermuteLayer({0, 2, 3, 1}, layerParams.name + "/permute", layerIds[op.inputs()->Get(0)]); // NCHW -> NHWC
layerIds[op.inputs()->Get(0)] = std::make_pair(permId, 0);
layouts[op.outputs()->Get(0)] = DATA_LAYOUT_NCHW;
}
layerParams.type = "Reshape";
auto options = reinterpret_cast<const ReshapeOptions*>(op.builtin_options());
std::vector<int> shape(options->new_shape()->begin(), options->new_shape()->end());
// std::swap(shape[1], shape[2]);
layerParams.set("dim", DictValue::arrayInt<int*>(shape.data(), shape.size()));
}
void TFLiteImporter::parseConcat(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
layerParams.type = "Concat";
auto options = reinterpret_cast<const ConcatenationOptions*>(op.builtin_options());
if (options->fused_activation_function() != ActivationFunctionType_NONE) {
CV_Error(Error::StsNotImplemented, "Concat with fused activation");
}
int axis = options->axis();
DataLayout inpLayout = layouts[op.inputs()->Get(0)];
if (inpLayout == DATA_LAYOUT_NHWC) {
// OpenCV works in NCHW data layout, so remap the axis accordingly.
CV_Check(axis, -4 < axis && axis < 4, "");
int remap[] = {0, 2, 3, 1};
axis = axis > 0 ? axis : 4 + axis;
axis = remap[axis];
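// e.g. TFLite axis 3 (channels in NHWC) maps to OpenCV axis 1 (channels in NCHW);
// a negative axis such as -1 is first normalized to 3 above.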
}
layerParams.set("axis", axis);
}
void TFLiteImporter::parseResize(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
layerParams.type = "Resize";
auto options = reinterpret_cast<const ResizeBilinearOptions*>(op.builtin_options());
layerParams.set("interpolation", "bilinear");
layerParams.set("align_corners", options->align_corners());
layerParams.set("half_pixel_centers", options->half_pixel_centers());
Mat shape = allTensors[op.inputs()->Get(1)].reshape(1, 1);
layerParams.set("height", shape.at<int>(0, 0));
layerParams.set("width", shape.at<int>(0, 1));
}
int TFLiteImporter::addPermuteLayer(const std::vector<int>& order, const std::string& permName,
const std::pair<int, int>& inpId)
{
LayerParams permLP;
permLP.set("order", DictValue::arrayInt<const int*>(order.data(), order.size()));
int permId = dstNet.addLayer(permName, "Permute", permLP);
dstNet.connect(inpId.first, inpId.second, permId, 0);
return permId;
}
void TFLiteImporter::parseDeconvolution(const Operator& op, const std::string& opcode, LayerParams& layerParams) {
layerParams.type = "Deconvolution";
CV_CheckLE(op.custom_options()->size(), sizeof(TfLiteTransposeConvParams), "");
const auto* params = reinterpret_cast<const TfLiteTransposeConvParams*>(op.custom_options()->Data());
if (params->padding != kTfLitePaddingUnknown)
layerParams.set("pad_mode", params->padding == kTfLitePaddingSame ? "SAME" : "VALID");
layerParams.set("stride_w", params->stride_width);
layerParams.set("stride_h", params->stride_height);
// Get filter size
int filterIdx = op.inputs()->Get(1);
Mat filter = allTensors[filterIdx];
int oc = filter.size[0];
int kh = filter.size[1];
int kw = filter.size[2];
int ic = filter.size[3];
layerParams.set("kernel_w", kw);
layerParams.set("kernel_h", kh);
layerParams.set("num_output", oc);
// Add adjustment padding, similar to the TensorFlow importer (see tf_importer)
const auto* outShape = modelTensors->Get(op.outputs()->Get(0))->shape();
const int outH = outShape->Get(1);
const int outW = outShape->Get(2);
if (params->padding == kTfLitePaddingSame)
{
layerParams.set("adj_w", (outW - 1) % params->stride_width);
layerParams.set("adj_h", (outH - 1) % params->stride_height);
}
else if (params->padding == kTfLitePaddingValid)
{
layerParams.set("adj_w", (outW - kw) % params->stride_width);
layerParams.set("adj_h", (outH - kh) % params->stride_height);
}
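// e.g. SAME padding with stride 2 and output height 128 gives adj_h = (128 - 1) % 2 = 1.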
// Reorder filter data from OHWI to IOHW and change the shape accordingly.
filter = allTensors[filterIdx] = filter.reshape(1, {ic, oc, kh, kw});
CV_CheckTypeEQ(filter.type(), CV_32F, "");
Mat filterCopy = filter.clone();
float* data = filterCopy.ptr<float>();
float* dstData = filter.ptr<float>();
int total = oc * ic * kh * kw;
for (int i_oc = 0; i_oc < oc; i_oc++) {
for (int i_ic = 0; i_ic < ic; i_ic++) {
for (int i_h = 0; i_h < kh; i_h++) {
for (int i_w = 0; i_w < kw; i_w++) {
int dst_i = kw * (kh * (oc * i_ic + i_oc) + i_h) + i_w;
int src_i = ic * (kw * (kh * i_oc + i_h) + i_w) + i_ic;
CV_CheckLT(dst_i, total, "");
CV_CheckLT(src_i, total, "");
dstData[dst_i] = data[src_i];
}
}
}
}
}
Net readNetFromTFLite(const String &modelPath) {
Net net;
std::vector<char> content;
const std::ios::openmode mode = std::ios::in | std::ios::binary;
std::ifstream ifs(modelPath, mode);
if (!ifs.is_open())
CV_Error(Error::StsError, cv::format("DNN/TFLite: can't open model file '%s'", modelPath.c_str()));
ifs.seekg(0, std::ios::end);
const size_t sz = ifs.tellg();
CV_Assert(sz > 0);
content.resize(sz);
ifs.seekg(0, std::ios::beg);
ifs.read(content.data(), sz);
CV_Assert(!ifs.bad());
TFLiteImporter(net, content.data(), content.size());
return net;
}
Net readNetFromTFLite(const std::vector<uchar>& bufferModel) {
return readNetFromTFLite((const char*)bufferModel.data(), bufferModel.size());
}
Net readNetFromTFLite(const char *bufferModel, size_t bufSize) {
Net net;
TFLiteImporter(net, bufferModel, bufSize);
return net;
}
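The buffer overloads allow loading a model that is already in memory; a sketch equivalent to the file-based reader above (hypothetical helper name):

#include <fstream>
#include <iterator>
#include <string>
#include <vector>
#include <opencv2/dnn.hpp>

cv::dnn::Net loadTFLiteFromBuffer(const std::string& path) {
    std::ifstream ifs(path, std::ios::binary);
    std::vector<uchar> buf((std::istreambuf_iterator<char>(ifs)),
                           std::istreambuf_iterator<char>());
    return cv::dnn::readNetFromTFLite(buf);
}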
#else // HAVE_FLATBUFFERS
#define DNN_TFLITE_UNSUPPORTED() CV_Error(Error::StsError, "DNN/TFLite: Build OpenCV with FlatBuffers to import TFLite models: https://github.com/opencv/opencv/pull/23161")
Net readNetFromTFLite(const String &) {
DNN_TFLITE_UNSUPPORTED();
}
Net readNetFromTFLite(const std::vector<uchar>&) {
DNN_TFLITE_UNSUPPORTED();
}
Net readNetFromTFLite(const char *, size_t) {
DNN_TFLITE_UNSUPPORTED();
}
#endif // HAVE_FLATBUFFERS
CV__DNN_INLINE_NS_END
}} // namespace cv::dnn


@@ -0,0 +1,123 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
/*
Test for TFLite models loading
*/
#include "test_precomp.hpp"
#include "npy_blob.hpp"
#include <opencv2/dnn/layer.details.hpp> // CV_DNN_REGISTER_LAYER_CLASS
#include <opencv2/dnn/utils/debug_utils.hpp>
namespace opencv_test
{
using namespace cv;
using namespace cv::dnn;
void testModel(const std::string& modelName, const Mat& input, double norm = 1e-5) {
#ifndef HAVE_FLATBUFFERS
throw SkipTestException("FlatBuffers required for TFLite importer");
#endif
Net net = readNet(findDataFile("dnn/tflite/" + modelName + ".tflite", false));
net.setInput(input);
std::vector<String> outNames = net.getUnconnectedOutLayersNames();
std::vector<Mat> outs;
net.forward(outs, outNames);
ASSERT_EQ(outs.size(), outNames.size());
for (int i = 0; i < outNames.size(); ++i) {
Mat ref = blobFromNPY(findDataFile(format("dnn/tflite/%s_out_%s.npy", modelName.c_str(), outNames[i].c_str())));
normAssert(ref.reshape(1, 1), outs[i].reshape(1, 1), outNames[i].c_str(), norm);
}
}
void testModel(const std::string& modelName, const Size& inpSize, double norm = 1e-5) {
Mat input = imread(findDataFile("cv/shared/lena.png"));
input = blobFromImage(input, 1.0 / 255, inpSize, 0, true);
testModel(modelName, input, norm);
}
// https://google.github.io/mediapipe/solutions/face_mesh
TEST(Test_TFLite, face_landmark)
{
testModel("face_landmark", Size(192, 192), 2e-5);
}
// https://google.github.io/mediapipe/solutions/face_detection
TEST(Test_TFLite, face_detection_short_range)
{
testModel("face_detection_short_range", Size(128, 128));
}
// https://google.github.io/mediapipe/solutions/selfie_segmentation
TEST(Test_TFLite, selfie_segmentation)
{
testModel("selfie_segmentation", Size(256, 256));
}
TEST(Test_TFLite, max_unpooling)
{
#ifndef HAVE_FLATBUFFERS
throw SkipTestException("FlatBuffers required for TFLite importer");
#endif
// Max Unpooling is a numerically unstable operation: small differences between
// frameworks may change the positions of maximal elements in the tensor. So this
// test checks the behavior of the Max Unpooling layer only.
Net net = readNet(findDataFile("dnn/tflite/hair_segmentation.tflite", false));
Mat input = imread(findDataFile("cv/shared/lena.png"));
cvtColor(input, input, COLOR_BGR2RGBA);
input = input.mul(Scalar(1, 1, 1, 0));
input = blobFromImage(input, 1.0 / 255);
net.setInput(input);
std::vector<std::vector<Mat> > outs;
net.forward(outs, {"p_re_lu_1", "max_pooling_with_argmax2d", "conv2d_86", "max_unpooling2d_2"});
ASSERT_EQ(outs.size(), 4);
ASSERT_EQ(outs[0].size(), 1);
ASSERT_EQ(outs[1].size(), 2);
ASSERT_EQ(outs[2].size(), 1);
ASSERT_EQ(outs[3].size(), 1);
Mat poolInp = outs[0][0];
Mat poolOut = outs[1][0];
Mat poolIds = outs[1][1];
Mat unpoolInp = outs[2][0];
Mat unpoolOut = outs[3][0];
ASSERT_EQ(poolInp.size, unpoolOut.size);
ASSERT_EQ(poolOut.size, poolIds.size);
ASSERT_EQ(poolOut.size, unpoolInp.size);
for (int c = 0; c < 32; ++c) {
float *poolInpData = poolInp.ptr<float>(0, c);
float *poolOutData = poolOut.ptr<float>(0, c);
float *poolIdsData = poolIds.ptr<float>(0, c);
float *unpoolInpData = unpoolInp.ptr<float>(0, c);
float *unpoolOutData = unpoolOut.ptr<float>(0, c);
for (int y = 0; y < 64; ++y) {
for (int x = 0; x < 64; ++x) {
int maxIdx = (y * 128 + x) * 2;
std::vector<int> indices{maxIdx + 1, maxIdx + 128, maxIdx + 129};
std::string errMsg = format("Channel %d, y: %d, x: %d", c, y, x);
for (int idx : indices) {
if (poolInpData[idx] > poolInpData[maxIdx]) {
EXPECT_EQ(unpoolOutData[maxIdx], 0.0f) << errMsg;
maxIdx = idx;
}
}
EXPECT_EQ(poolInpData[maxIdx], poolOutData[y * 64 + x]) << errMsg;
EXPECT_EQ(poolIdsData[y * 64 + x], (float)maxIdx) << errMsg;
EXPECT_EQ(unpoolOutData[maxIdx], unpoolInpData[y * 64 + x]) << errMsg;
}
}
}
}
}


@@ -135,7 +135,7 @@ video = {
dnn = {'dnn_Net': ['setInput', 'forward', 'setPreferableBackend'],
'': ['readNetFromCaffe', 'readNetFromTensorflow', 'readNetFromTorch', 'readNetFromDarknet',
'readNetFromONNX', 'readNet', 'blobFromImage']}
'readNetFromONNX', 'readNetFromTFLite', 'readNet', 'blobFromImage']}
features2d = {'Feature2D': ['detect', 'compute', 'detectAndCompute', 'descriptorSize', 'descriptorType', 'defaultNorm', 'empty', 'getDefaultName'],
'BRISK': ['create', 'getDefaultName'],