mirror of
https://github.com/opencv/opencv.git
synced 2025-06-08 01:53:19 +08:00
Sort text TensorFlow graphs
This commit is contained in:
parent
7b28d5b409
commit
f954f0830c
@ -950,6 +950,7 @@ void sortByExecutionOrder(tensorflow::GraphDef& net)
|
|||||||
for (int i = 0; i < net.node_size(); ++i)
|
for (int i = 0; i < net.node_size(); ++i)
|
||||||
{
|
{
|
||||||
const tensorflow::NodeDef& node = net.node(i);
|
const tensorflow::NodeDef& node = net.node(i);
|
||||||
|
int numInputsInGraph = 0;
|
||||||
for (int j = 0; j < node.input_size(); ++j)
|
for (int j = 0; j < node.input_size(); ++j)
|
||||||
{
|
{
|
||||||
std::string inpName = node.input(j);
|
std::string inpName = node.input(j);
|
||||||
@ -957,22 +958,25 @@ void sortByExecutionOrder(tensorflow::GraphDef& net)
|
|||||||
inpName = inpName.substr(inpName.find('^') + 1);
|
inpName = inpName.substr(inpName.find('^') + 1);
|
||||||
|
|
||||||
nodesMapIt = nodesMap.find(inpName);
|
nodesMapIt = nodesMap.find(inpName);
|
||||||
CV_Assert(nodesMapIt != nodesMap.end());
|
if (nodesMapIt != nodesMap.end())
|
||||||
edges[nodesMapIt->second].push_back(i);
|
{
|
||||||
|
edges[nodesMapIt->second].push_back(i);
|
||||||
|
numInputsInGraph += 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (node.input_size() == 0)
|
if (numInputsInGraph == 0)
|
||||||
nodesToAdd.push_back(i);
|
nodesToAdd.push_back(i);
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
if (node.op() == "Merge" || node.op() == "RefMerge")
|
if (node.op() == "Merge" || node.op() == "RefMerge")
|
||||||
{
|
{
|
||||||
int numControlEdges = 0;
|
int numControlEdges = 0;
|
||||||
for (int j = 0; j < node.input_size(); ++j)
|
for (int j = 0; j < numInputsInGraph; ++j)
|
||||||
numControlEdges += node.input(j)[0] == '^';
|
numControlEdges += node.input(j)[0] == '^';
|
||||||
numRefsToAdd[i] = numControlEdges + 1;
|
numRefsToAdd[i] = numControlEdges + 1;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
numRefsToAdd[i] = node.input_size();
|
numRefsToAdd[i] = numInputsInGraph;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -715,6 +715,10 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
simplifySubgraphs(netBin);
|
simplifySubgraphs(netBin);
|
||||||
sortByExecutionOrder(netBin);
|
sortByExecutionOrder(netBin);
|
||||||
}
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
sortByExecutionOrder(netTxt);
|
||||||
|
}
|
||||||
|
|
||||||
std::set<String> layers_to_ignore;
|
std::set<String> layers_to_ignore;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user