mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 17:44:04 +08:00
Merge pull request #11826 from dkurt:dnn_tf_data_layouts
This commit is contained in:
commit
b80c7bca0d
@ -18,6 +18,7 @@ Implementation of Tensorflow models parser
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <queue>
|
||||
#include "tf_graph_simplifier.hpp"
|
||||
#endif
|
||||
|
||||
@ -558,9 +559,7 @@ static void addConstNodes(tensorflow::GraphDef& net, std::map<String, int>& cons
|
||||
}
|
||||
}
|
||||
|
||||
// If all inputs of specific layer have the same data layout we can say that
|
||||
// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
|
||||
static int predictOutputDataLayout(const tensorflow::NodeDef& layer, const std::map<String, int>& data_layouts)
|
||||
static int getDataLayout(const tensorflow::NodeDef& layer)
|
||||
{
|
||||
if (hasLayerAttr(layer, "data_format"))
|
||||
{
|
||||
@ -572,27 +571,48 @@ static int predictOutputDataLayout(const tensorflow::NodeDef& layer, const std::
|
||||
else
|
||||
CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
|
||||
}
|
||||
return DATA_LAYOUT_UNKNOWN;
|
||||
}
|
||||
|
||||
static inline std::string getNodeName(const std::string& tensorName)
|
||||
{
|
||||
return tensorName.substr(0, tensorName.rfind(':'));
|
||||
}
|
||||
|
||||
// If all inputs of specific layer have the same data layout we can say that
|
||||
// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
|
||||
static int predictOutputDataLayout(const tensorflow::GraphDef& net,
|
||||
const tensorflow::NodeDef& layer,
|
||||
const std::map<String, int>& data_layouts)
|
||||
{
|
||||
int layout = getDataLayout(layer);
|
||||
if (layout != DATA_LAYOUT_UNKNOWN)
|
||||
return layout;
|
||||
|
||||
// Determine layout by layer's inputs
|
||||
int layout = DATA_LAYOUT_UNKNOWN;
|
||||
std::map<String, int>::const_iterator it;
|
||||
for (int i = 0, n = layer.input_size(); i < n; ++i)
|
||||
{
|
||||
it = data_layouts.find(layer.input(i).substr(0, layer.input(i).rfind(':')));
|
||||
it = data_layouts.find(getNodeName(layer.input(i)));
|
||||
if (it != data_layouts.end())
|
||||
{
|
||||
if (it->second == DATA_LAYOUT_UNKNOWN)
|
||||
return DATA_LAYOUT_UNKNOWN;
|
||||
else if (it->second != layout)
|
||||
if (layout != DATA_LAYOUT_UNKNOWN)
|
||||
{
|
||||
if (layout == DATA_LAYOUT_UNKNOWN)
|
||||
layout = it->second;
|
||||
else
|
||||
if (it->second != layout && it->second != DATA_LAYOUT_UNKNOWN)
|
||||
return DATA_LAYOUT_UNKNOWN;
|
||||
}
|
||||
else
|
||||
layout = it->second;
|
||||
}
|
||||
}
|
||||
return layout;
|
||||
|
||||
if (layout != DATA_LAYOUT_UNKNOWN)
|
||||
return layout;
|
||||
|
||||
// Determine layout by layer's consumers recursively.
|
||||
it = data_layouts.find(layer.name());
|
||||
CV_Assert(it != data_layouts.end());
|
||||
return it->second;
|
||||
}
|
||||
|
||||
void TFImporter::populateNet(Net dstNet)
|
||||
@ -610,6 +630,52 @@ void TFImporter::populateNet(Net dstNet)
|
||||
int layersSize = net.node_size();
|
||||
|
||||
std::map<String, int> data_layouts;
|
||||
// Pre-fill data layouts where they are set explicitly.
|
||||
// Assuming that nodes are in topological order
|
||||
for (int i = net.node_size() - 1; i >= 0; --i)
|
||||
{
|
||||
const tensorflow::NodeDef& layer = net.node(i);
|
||||
std::string name = layer.name();
|
||||
|
||||
int layout = getDataLayout(layer);
|
||||
std::map<String, int>::iterator it = data_layouts.find(name);
|
||||
if (it != data_layouts.end())
|
||||
{
|
||||
if (layout != DATA_LAYOUT_UNKNOWN)
|
||||
{
|
||||
if (it->second == DATA_LAYOUT_UNKNOWN)
|
||||
it->second = layout;
|
||||
else if (it->second != layout)
|
||||
{
|
||||
it->second = DATA_LAYOUT_UNKNOWN;
|
||||
layout = DATA_LAYOUT_UNKNOWN;
|
||||
}
|
||||
}
|
||||
else
|
||||
layout = it->second;
|
||||
}
|
||||
else
|
||||
data_layouts[name] = layout;
|
||||
|
||||
// Specify input layers to have the same data layout.
|
||||
for (int j = 0; j < layer.input_size(); ++j)
|
||||
{
|
||||
name = getNodeName(layer.input(j));
|
||||
it = data_layouts.find(name);
|
||||
if (it != data_layouts.end())
|
||||
{
|
||||
if (layout != DATA_LAYOUT_UNKNOWN)
|
||||
{
|
||||
if (it->second == DATA_LAYOUT_UNKNOWN)
|
||||
it->second = layout;
|
||||
else if (it->second != layout)
|
||||
it->second = DATA_LAYOUT_UNKNOWN;
|
||||
}
|
||||
}
|
||||
else
|
||||
data_layouts[name] = layout;
|
||||
}
|
||||
}
|
||||
|
||||
// find all Const layers for params
|
||||
std::map<String, int> value_id;
|
||||
@ -628,7 +694,8 @@ void TFImporter::populateNet(Net dstNet)
|
||||
if(layers_to_ignore.find(name) != layers_to_ignore.end())
|
||||
continue;
|
||||
|
||||
data_layouts[name] = predictOutputDataLayout(layer, data_layouts);
|
||||
int predictedLayout = predictOutputDataLayout(net, layer, data_layouts);
|
||||
data_layouts[name] = predictedLayout;
|
||||
|
||||
if (type == "Conv2D" || type == "SpaceToBatchND" || type == "DepthwiseConv2dNative")
|
||||
{
|
||||
@ -885,6 +952,7 @@ void TFImporter::populateNet(Net dstNet)
|
||||
|
||||
// one input only
|
||||
connect(layer_id, dstNet, inpId, id, 0);
|
||||
data_layouts[name] = DATA_LAYOUT_UNKNOWN;
|
||||
}
|
||||
else if (type == "Flatten" || type == "Squeeze")
|
||||
{
|
||||
@ -1013,7 +1081,10 @@ void TFImporter::populateNet(Net dstNet)
|
||||
{
|
||||
int axisId = (type == "Concat" ? 0 : layer.input_size() - 1);
|
||||
int axis = getConstBlob(layer, value_id, axisId).int_val().Get(0);
|
||||
layerParams.set("axis", 0 <= axis && axis < 4 ? toNCHW(axis) : axis);
|
||||
|
||||
if (data_layouts[name] == DATA_LAYOUT_NHWC)
|
||||
axis = toNCHW(axis);
|
||||
layerParams.set("axis", axis);
|
||||
|
||||
int id = dstNet.addLayer(name, "Concat", layerParams);
|
||||
layer_id[name] = id;
|
||||
|
@ -142,9 +142,10 @@ TEST_P(Test_TensorFlow_layers, eltwise_add_mul)
|
||||
runTensorFlowNet("eltwise_add_mul", GetParam());
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, pad_and_concat)
|
||||
TEST_P(Test_TensorFlow_layers, concat)
|
||||
{
|
||||
runTensorFlowNet("pad_and_concat", GetParam());
|
||||
runTensorFlowNet("concat_axis_1", GetParam());
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, batch_norm)
|
||||
|
Loading…
Reference in New Issue
Block a user