mirror of
https://github.com/opencv/opencv.git
synced 2024-11-25 19:50:38 +08:00
Fuse batch normalization and flatten TensorFlow subgraphs in runtime
This commit is contained in:
parent
5b868ccd82
commit
9457bf10ab
434
modules/dnn/src/tensorflow/tf_graph_editor.cpp
Normal file
434
modules/dnn/src/tensorflow/tf_graph_editor.cpp
Normal file
@ -0,0 +1,434 @@
|
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
// Copyright (C) 2018, Intel Corporation, all rights reserved.
|
||||
// Third party copyrights are property of their respective owners.
|
||||
|
||||
#ifdef HAVE_PROTOBUF
|
||||
|
||||
#include "tf_graph_editor.hpp"
|
||||
|
||||
namespace cv { namespace dnn {
|
||||
CV__DNN_EXPERIMENTAL_NS_BEGIN
|
||||
|
||||
using ::google::protobuf::RepeatedField;
|
||||
using ::google::protobuf::MapPair;
|
||||
|
||||
class Subgraph // Interface to match and replace TensorFlow subgraphs.
|
||||
{
|
||||
public:
|
||||
// Add a node to be matched in the origin graph. Specify ids of nodes that
|
||||
// are expected to be inputs. Returns id of a newly added node.
|
||||
// TODO: Replace inputs to std::vector<int> in C++11
|
||||
int addNodeToMatch(const std::string& op, int input_0 = -1, int input_1 = -1,
|
||||
int input_2 = -1, int input_3 = -1)
|
||||
{
|
||||
int nodeInputs[] = {input_0, input_1, input_2, input_3};
|
||||
int numInputs = 0;
|
||||
for (int i = 0; i < 4; ++i)
|
||||
{
|
||||
CV_Assert(nodeInputs[i] < (int)nodes.size());
|
||||
numInputs += (int)(nodeInputs[i] != -1);
|
||||
}
|
||||
nodes.push_back(op);
|
||||
inputs.push_back(std::vector<int>(&nodeInputs[0], &nodeInputs[0] + numInputs));
|
||||
return nodes.size() - 1;
|
||||
}
|
||||
|
||||
// Specify resulting node. All the matched nodes in subgraph excluding
|
||||
// input nodes will be fused into this single node.
|
||||
// TODO: Replace inputs to std::vector<int> in C++11
|
||||
void setFusedNode(const std::string& op, int input_0 = -1, int input_1 = -1,
|
||||
int input_2 = -1, int input_3 = -1, int input_4 = -1,
|
||||
int input_5 = -1)
|
||||
{
|
||||
int nodeInputs[] = {input_0, input_1, input_2, input_3, input_4, input_5};
|
||||
int numInputs = 0;
|
||||
for (int i = 0; i < 6; ++i)
|
||||
{
|
||||
CV_Assert(nodeInputs[i] < (int)nodes.size());
|
||||
numInputs += (int)(nodeInputs[i] != -1);
|
||||
}
|
||||
fusedNodeInputs = std::vector<int>(&nodeInputs[0], &nodeInputs[0] + numInputs);
|
||||
|
||||
fusedNodeOp = op;
|
||||
nodesToFuse.clear();
|
||||
for (int i = 0; i < nodes.size(); ++i)
|
||||
{
|
||||
if (std::find(fusedNodeInputs.begin(), fusedNodeInputs.end(), i) == fusedNodeInputs.end())
|
||||
nodesToFuse.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
static const tensorflow::NodeDef& getInputNode(const tensorflow::GraphDef& net,
|
||||
const tensorflow::NodeDef& node,
|
||||
int inpId)
|
||||
{
|
||||
CV_Assert(inpId < node.input_size());
|
||||
std::string name = node.input(inpId);
|
||||
const int numNodes = net.node_size();
|
||||
for (int i = 0; i < numNodes; ++i)
|
||||
{
|
||||
const tensorflow::NodeDef& node = net.node(i);
|
||||
if (node.name() == name)
|
||||
return node;
|
||||
}
|
||||
CV_Error(Error::StsParseError, "Input node with name " + name + " not found");
|
||||
return net.node(0); // just return something
|
||||
}
|
||||
|
||||
// Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
|
||||
// Returns true if nodes are matched and can be fused.
|
||||
bool match(const tensorflow::GraphDef& net, int nodeId, int* numMatchedNodes)
|
||||
{
|
||||
*numMatchedNodes = 0;
|
||||
int numNodes = net.node_size();
|
||||
for (int i = 0; i < nodesToFuse.size(); ++i)
|
||||
{
|
||||
if (nodeId + i > numNodes - 1)
|
||||
return false;
|
||||
|
||||
const tensorflow::NodeDef &node = net.node(nodeId + i);
|
||||
if (node.op() != nodes[nodesToFuse[i]])
|
||||
return false;
|
||||
|
||||
std::vector<int>& inputNodes = inputs[nodesToFuse[i]];
|
||||
if (inputNodes.size() != node.input_size())
|
||||
return false;
|
||||
for (int j = 0; j < inputNodes.size(); ++j)
|
||||
{
|
||||
if (nodes[inputNodes[j]].empty()) // Unknown input node type.
|
||||
continue;
|
||||
const tensorflow::NodeDef& inpNode = getInputNode(net, node, j);
|
||||
if (inpNode.op() != nodes[inputNodes[j]])
|
||||
return false;
|
||||
}
|
||||
|
||||
*numMatchedNodes += 1;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Fuse matched subgraph.
|
||||
void replace(tensorflow::GraphDef& net, int nodeId, int* numReplacedNodes)
|
||||
{
|
||||
*numReplacedNodes = 0;
|
||||
|
||||
// Extract names of input nodes.
|
||||
std::vector<std::string> inputsNames(fusedNodeInputs.size());
|
||||
for (int i = 0; i < fusedNodeInputs.size(); ++i)
|
||||
{
|
||||
std::string inpName;
|
||||
// Find input node name looking at inputs of fused nodes.
|
||||
for (int j = 0; j < nodesToFuse.size() && inpName.empty(); ++j)
|
||||
{
|
||||
const tensorflow::NodeDef &node = net.node(nodeId + j);
|
||||
std::vector<int>& inpIndices = inputs[nodesToFuse[j]];
|
||||
|
||||
CV_Assert(node.input_size() == inpIndices.size());
|
||||
for (int k = 0; k < inpIndices.size(); ++k)
|
||||
{
|
||||
if (inpIndices[k] == fusedNodeInputs[i])
|
||||
{
|
||||
inpName = node.input(k);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
CV_Assert(!inpName.empty());
|
||||
inputsNames[i] = inpName;
|
||||
}
|
||||
|
||||
// Remove all nodes except the last one.
|
||||
*numReplacedNodes = nodesToFuse.size() - 1;
|
||||
net.mutable_node()->DeleteSubrange(nodeId, *numReplacedNodes);
|
||||
|
||||
// Modify the last node to be a fused one.
|
||||
tensorflow::NodeDef* node = net.mutable_node(nodeId);
|
||||
node->set_op(fusedNodeOp);
|
||||
node->clear_input();
|
||||
for (int i = 0; i < inputsNames.size(); ++i)
|
||||
{
|
||||
node->add_input(inputsNames[i]);
|
||||
}
|
||||
|
||||
std::vector<tensorflow::NodeDef> inputNodes(inputsNames.size());
|
||||
for (int i = 0; i < inputsNames.size(); ++i)
|
||||
{
|
||||
inputNodes[i] = getInputNode(net, *node, i);
|
||||
}
|
||||
finalize(net, node, inputNodes);
|
||||
}
|
||||
|
||||
virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef*,
|
||||
const std::vector<tensorflow::NodeDef>&) {}
|
||||
|
||||
private:
|
||||
std::vector<std::string> nodes; // Nodes to be matched in the origin graph.
|
||||
std::vector<std::vector<int> > inputs; // Connections of an every node to it's inputs.
|
||||
|
||||
std::string fusedNodeOp; // Operation name of resulting fused node.
|
||||
std::vector<int> nodesToFuse; // Set of nodes to be fused.
|
||||
std::vector<int> fusedNodeInputs; // Inputs of fused node.
|
||||
};
|
||||
|
||||
class BatchNormSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
BatchNormSubgraph()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int epsilon = addNodeToMatch("Const");
|
||||
int moving_variance = addNodeToMatch("Const");
|
||||
int moving_mean = addNodeToMatch("Const");
|
||||
int beta = addNodeToMatch("Const");
|
||||
int gamma = addNodeToMatch("Const");
|
||||
int add = addNodeToMatch("Add", moving_variance, epsilon);
|
||||
int rsqrt = addNodeToMatch("Rsqrt", add);
|
||||
int mul = addNodeToMatch("Mul", rsqrt, gamma);
|
||||
int mul_1 = addNodeToMatch("Mul", input, mul);
|
||||
int mul_2 = addNodeToMatch("Mul", moving_mean, mul);
|
||||
int sub = addNodeToMatch("Sub", beta, mul_2);
|
||||
addNodeToMatch("Add", mul_1, sub);
|
||||
|
||||
setFusedNode("FusedBatchNorm", input, gamma, beta, moving_mean, moving_variance, epsilon);
|
||||
}
|
||||
|
||||
virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
|
||||
const std::vector<tensorflow::NodeDef>& inputNodes)
|
||||
{
|
||||
Mat epsMat = getTensorContent(inputNodes.back().attr().at("value").tensor());
|
||||
CV_Assert(epsMat.total() == 1, epsMat.type() == CV_32FC1);
|
||||
|
||||
fusedNode->mutable_input()->ReleaseLast();
|
||||
fusedNode->clear_attr();
|
||||
tensorflow::AttrValue epsilon;
|
||||
epsilon.set_f(epsMat.at<float>(0));
|
||||
fusedNode->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("epsilon", epsilon));
|
||||
}
|
||||
};
|
||||
|
||||
class BatchNormNoGammaSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
BatchNormNoGammaSubgraph()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int epsilon = addNodeToMatch("Const");
|
||||
int moving_variance = addNodeToMatch("Const");
|
||||
int moving_mean = addNodeToMatch("Const");
|
||||
int beta = addNodeToMatch("Const");
|
||||
int add = addNodeToMatch("Add", moving_variance, epsilon);
|
||||
int rsqrt = addNodeToMatch("Rsqrt", add);
|
||||
int mul = addNodeToMatch("Mul", input, rsqrt);
|
||||
int mul_1 = addNodeToMatch("Mul", moving_mean, rsqrt);
|
||||
int sub = addNodeToMatch("Sub", beta, mul_1);
|
||||
addNodeToMatch("Add", mul, sub);
|
||||
|
||||
// There is a fake reference to beta that will be replaced to a new gamma tensor.
|
||||
setFusedNode("FusedBatchNorm", input, beta, beta, moving_mean, moving_variance, epsilon);
|
||||
}
|
||||
|
||||
virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode,
|
||||
const std::vector<tensorflow::NodeDef>& inputNodes)
|
||||
{
|
||||
Mat epsMat = getTensorContent(inputNodes.back().attr().at("value").tensor());
|
||||
CV_Assert(epsMat.total() == 1, epsMat.type() == CV_32FC1);
|
||||
|
||||
fusedNode->mutable_input()->ReleaseLast();
|
||||
fusedNode->clear_attr();
|
||||
tensorflow::AttrValue epsilon;
|
||||
epsilon.set_f(epsMat.at<float>(0));
|
||||
fusedNode->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("epsilon", epsilon));
|
||||
|
||||
tensorflow::NodeDef* gamma = net.add_node();
|
||||
gamma->set_op("Const");
|
||||
gamma->set_name(fusedNode->name() + "/gamma");
|
||||
// Just put a single value to recognize this node as Const.
|
||||
gamma->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("value", epsilon));
|
||||
fusedNode->set_input(1, gamma->name());
|
||||
}
|
||||
};
|
||||
|
||||
// tf.contrib.layers.flatten
|
||||
class FlattenSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
FlattenSubgraph()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int shape = addNodeToMatch("Const");
|
||||
int stack = addNodeToMatch("Const");
|
||||
int stack_1 = addNodeToMatch("Const");
|
||||
int stack_2 = addNodeToMatch("Const");
|
||||
int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
|
||||
int shape_pack = addNodeToMatch("Const");
|
||||
int pack = addNodeToMatch("Pack", strided_slice, shape_pack);
|
||||
addNodeToMatch("Reshape", input, pack);
|
||||
|
||||
setFusedNode("Flatten", input);
|
||||
}
|
||||
};
|
||||
|
||||
// tf.contrib.layers.flatten in case of unknown batch size
|
||||
class FlattenShapeSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
FlattenShapeSubgraph()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int shape = addNodeToMatch("Shape", input);
|
||||
int stack = addNodeToMatch("Const");
|
||||
int stack_1 = addNodeToMatch("Const");
|
||||
int stack_2 = addNodeToMatch("Const");
|
||||
int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
|
||||
int shape_pack = addNodeToMatch("Const");
|
||||
int pack = addNodeToMatch("Pack", strided_slice, shape_pack);
|
||||
addNodeToMatch("Reshape", input, pack);
|
||||
|
||||
setFusedNode("Flatten", input);
|
||||
}
|
||||
};
|
||||
|
||||
void simplifySubgraphs(tensorflow::GraphDef& net)
|
||||
{
|
||||
std::vector<Ptr<Subgraph> > subgraphs;
|
||||
subgraphs.push_back(Ptr<Subgraph>(new BatchNormSubgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new BatchNormNoGammaSubgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new FlattenSubgraph()));
|
||||
subgraphs.push_back(Ptr<Subgraph>(new FlattenShapeSubgraph()));
|
||||
|
||||
int numNodes = net.node_size();
|
||||
int numMatchedNodes, numReplacedNodes;
|
||||
for (int i = 0; i < numNodes; ++i)
|
||||
{
|
||||
for (int j = 0; j < subgraphs.size(); ++j)
|
||||
{
|
||||
if (subgraphs[j]->match(net, i, &numMatchedNodes))
|
||||
{
|
||||
subgraphs[j]->replace(net, i, &numReplacedNodes);
|
||||
numNodes -= numReplacedNodes;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void RemoveIdentityOps(tensorflow::GraphDef& net)
|
||||
{
|
||||
typedef std::map<String, String> IdentityOpsMap;
|
||||
IdentityOpsMap identity_ops;
|
||||
|
||||
std::vector<int> identity_ops_idx;
|
||||
|
||||
int layersCount = net.node_size();
|
||||
for (int li = 0; li < layersCount; li++)
|
||||
{
|
||||
const tensorflow::NodeDef &layer = net.node(li);
|
||||
String type = layer.op();
|
||||
|
||||
if (type == "Identity" || type == "Dropout") {
|
||||
identity_ops_idx.push_back(li);
|
||||
identity_ops[layer.name()] = layer.input(0);
|
||||
}
|
||||
}
|
||||
|
||||
for (int li = 0; li < layersCount; li++)
|
||||
{
|
||||
tensorflow::NodeDef* layer = net.mutable_node(li);
|
||||
for (int input_id = 0; input_id < layer->input_size(); input_id++) {
|
||||
String input_op_name = layer->input(input_id);
|
||||
IdentityOpsMap::iterator it = identity_ops.find(input_op_name);
|
||||
|
||||
if (it != identity_ops.end()) {
|
||||
layer->set_input(input_id, it->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(identity_ops_idx.begin(), identity_ops_idx.end());
|
||||
|
||||
int removed_nodes = 0;
|
||||
for(size_t i = 0; i < identity_ops_idx.size(); i++) {
|
||||
int start_id = identity_ops_idx[i] - removed_nodes;
|
||||
net.mutable_node()->DeleteSubrange(start_id, 1);
|
||||
removed_nodes++;
|
||||
}
|
||||
}
|
||||
|
||||
Mat getTensorContent(const tensorflow::TensorProto &tensor)
|
||||
{
|
||||
std::string content = tensor.tensor_content();
|
||||
switch (tensor.dtype())
|
||||
{
|
||||
case tensorflow::DT_FLOAT:
|
||||
{
|
||||
if (!content.empty())
|
||||
return Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()).clone();
|
||||
else
|
||||
{
|
||||
const RepeatedField<float>& field = tensor.float_val();
|
||||
CV_Assert(!field.empty());
|
||||
return Mat(1, field.size(), CV_32FC1, (void*)field.data()).clone();
|
||||
}
|
||||
}
|
||||
case tensorflow::DT_DOUBLE:
|
||||
{
|
||||
if (!content.empty())
|
||||
return Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()).clone();
|
||||
else
|
||||
{
|
||||
const RepeatedField<double>& field = tensor.double_val();
|
||||
CV_Assert(!field.empty());
|
||||
return Mat(1, field.size(), CV_64FC1, (void*)field.data()).clone();
|
||||
}
|
||||
}
|
||||
case tensorflow::DT_INT32:
|
||||
{
|
||||
if (!content.empty())
|
||||
return Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()).clone();
|
||||
else
|
||||
{
|
||||
const RepeatedField<int32_t>& field = tensor.int_val();
|
||||
CV_Assert(!field.empty());
|
||||
return Mat(1, field.size(), CV_32SC1, (void*)field.data()).clone();
|
||||
}
|
||||
}
|
||||
case tensorflow::DT_HALF:
|
||||
{
|
||||
Mat halfs;
|
||||
if (!content.empty())
|
||||
{
|
||||
static const int kHalfSize = 2;
|
||||
halfs = Mat(1, content.size() / kHalfSize, CV_16UC1, (void*)content.c_str());
|
||||
}
|
||||
else
|
||||
{
|
||||
const RepeatedField<int32_t>& field = tensor.half_val();
|
||||
CV_Assert(!field.empty());
|
||||
Mat ints(1, field.size(), CV_32SC1, (void*)field.data());
|
||||
ints.convertTo(halfs, CV_16UC1);
|
||||
}
|
||||
// Reinterpret as a signed shorts just for a convertFp16 call.
|
||||
Mat halfsSigned(halfs.size(), CV_16SC1, halfs.data);
|
||||
Mat floats(halfs.size(), CV_32FC1);
|
||||
convertFp16(halfsSigned, floats);
|
||||
return floats;
|
||||
}
|
||||
case tensorflow::DT_QUINT8:
|
||||
{
|
||||
CV_Assert(!content.empty());
|
||||
return Mat(1, content.size(), CV_8UC1, (void*)content.c_str()).clone();
|
||||
}
|
||||
default:
|
||||
CV_Error(Error::StsError, "Tensor's data type is not supported");
|
||||
break;
|
||||
}
|
||||
return Mat();
|
||||
}
|
||||
|
||||
CV__DNN_EXPERIMENTAL_NS_END
|
||||
}} // namespace dnn, namespace cv
|
||||
|
||||
#endif // HAVE_PROTOBUF
|
30
modules/dnn/src/tensorflow/tf_graph_editor.hpp
Normal file
30
modules/dnn/src/tensorflow/tf_graph_editor.hpp
Normal file
@ -0,0 +1,30 @@
|
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
// Copyright (C) 2018, Intel Corporation, all rights reserved.
|
||||
// Third party copyrights are property of their respective owners.
|
||||
|
||||
#ifndef __OPENCV_DNN_TF_SIMPLIFIER_HPP__
|
||||
#define __OPENCV_DNN_TF_SIMPLIFIER_HPP__
|
||||
|
||||
#include "../precomp.hpp"
|
||||
|
||||
#ifdef HAVE_PROTOBUF
|
||||
|
||||
#include "tf_io.hpp"
|
||||
|
||||
namespace cv { namespace dnn {
|
||||
CV__DNN_EXPERIMENTAL_NS_BEGIN
|
||||
|
||||
void RemoveIdentityOps(tensorflow::GraphDef& net);
|
||||
|
||||
void simplifySubgraphs(tensorflow::GraphDef& net);
|
||||
|
||||
Mat getTensorContent(const tensorflow::TensorProto &tensor);
|
||||
|
||||
CV__DNN_EXPERIMENTAL_NS_END
|
||||
}} // namespace dnn, namespace cv
|
||||
|
||||
#endif // HAVE_PROTOBUF
|
||||
#endif // __OPENCV_DNN_TF_SIMPLIFIER_HPP__
|
@ -22,6 +22,7 @@ Implementation of Tensorflow models parser
|
||||
#include <google/protobuf/text_format.h>
|
||||
#include <google/protobuf/io/zero_copy_stream_impl.h>
|
||||
#include "tf_io.hpp"
|
||||
#include "tf_graph_editor.hpp"
|
||||
#endif
|
||||
|
||||
namespace cv {
|
||||
@ -87,77 +88,6 @@ void blobShapeFromTensor(const tensorflow::TensorProto &tensor, MatShape& shape)
|
||||
}
|
||||
}
|
||||
|
||||
static Mat getTensorContent(const tensorflow::TensorProto &tensor)
|
||||
{
|
||||
std::string content = tensor.tensor_content();
|
||||
switch (tensor.dtype())
|
||||
{
|
||||
case tensorflow::DT_FLOAT:
|
||||
{
|
||||
if (!content.empty())
|
||||
return Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()).clone();
|
||||
else
|
||||
{
|
||||
const RepeatedField<float>& field = tensor.float_val();
|
||||
CV_Assert(!field.empty());
|
||||
return Mat(1, field.size(), CV_32FC1, (void*)field.data()).clone();
|
||||
}
|
||||
}
|
||||
case tensorflow::DT_DOUBLE:
|
||||
{
|
||||
if (!content.empty())
|
||||
return Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()).clone();
|
||||
else
|
||||
{
|
||||
const RepeatedField<double>& field = tensor.double_val();
|
||||
CV_Assert(!field.empty());
|
||||
return Mat(1, field.size(), CV_64FC1, (void*)field.data()).clone();
|
||||
}
|
||||
}
|
||||
case tensorflow::DT_INT32:
|
||||
{
|
||||
if (!content.empty())
|
||||
return Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()).clone();
|
||||
else
|
||||
{
|
||||
const RepeatedField<int32_t>& field = tensor.int_val();
|
||||
CV_Assert(!field.empty());
|
||||
return Mat(1, field.size(), CV_32SC1, (void*)field.data()).clone();
|
||||
}
|
||||
}
|
||||
case tensorflow::DT_HALF:
|
||||
{
|
||||
Mat halfs;
|
||||
if (!content.empty())
|
||||
{
|
||||
static const int kHalfSize = 2;
|
||||
halfs = Mat(1, content.size() / kHalfSize, CV_16UC1, (void*)content.c_str());
|
||||
}
|
||||
else
|
||||
{
|
||||
const RepeatedField<int32_t>& field = tensor.half_val();
|
||||
CV_Assert(!field.empty());
|
||||
Mat ints(1, field.size(), CV_32SC1, (void*)field.data());
|
||||
ints.convertTo(halfs, CV_16UC1);
|
||||
}
|
||||
// Reinterpret as a signed shorts just for a convertFp16 call.
|
||||
Mat halfsSigned(halfs.size(), CV_16SC1, halfs.data);
|
||||
Mat floats(halfs.size(), CV_32FC1);
|
||||
convertFp16(halfsSigned, floats);
|
||||
return floats;
|
||||
}
|
||||
case tensorflow::DT_QUINT8:
|
||||
{
|
||||
CV_Assert(!content.empty());
|
||||
return Mat(1, content.size(), CV_8UC1, (void*)content.c_str()).clone();
|
||||
}
|
||||
default:
|
||||
CV_Error(Error::StsError, "Tensor's data type is not supported");
|
||||
break;
|
||||
}
|
||||
return Mat();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void parseTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
|
||||
{
|
||||
@ -364,47 +294,6 @@ void setPadding(LayerParams &layerParams, const tensorflow::NodeDef &layer)
|
||||
layerParams.set("pad_mode", getLayerAttr(layer, "padding").s());
|
||||
}
|
||||
|
||||
void RemoveIdentityOps(tensorflow::GraphDef& net) {
|
||||
typedef std::map<String, String> IdentityOpsMap;
|
||||
IdentityOpsMap identity_ops;
|
||||
|
||||
std::vector<int> identity_ops_idx;
|
||||
|
||||
int layersCount = net.node_size();
|
||||
for (int li = 0; li < layersCount; li++)
|
||||
{
|
||||
const tensorflow::NodeDef &layer = net.node(li);
|
||||
String type = layer.op();
|
||||
|
||||
if (type == "Identity" || type == "Dropout") {
|
||||
identity_ops_idx.push_back(li);
|
||||
identity_ops[layer.name()] = layer.input(0);
|
||||
}
|
||||
}
|
||||
|
||||
for (int li = 0; li < layersCount; li++)
|
||||
{
|
||||
tensorflow::NodeDef* layer = net.mutable_node(li);
|
||||
for (int input_id = 0; input_id < layer->input_size(); input_id++) {
|
||||
String input_op_name = layer->input(input_id);
|
||||
IdentityOpsMap::iterator it = identity_ops.find(input_op_name);
|
||||
|
||||
if (it != identity_ops.end()) {
|
||||
layer->set_input(input_id, it->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(identity_ops_idx.begin(), identity_ops_idx.end());
|
||||
|
||||
int removed_nodes = 0;
|
||||
for(size_t i = 0; i < identity_ops_idx.size(); i++) {
|
||||
int start_id = identity_ops_idx[i] - removed_nodes;
|
||||
net.mutable_node()->DeleteSubrange(start_id, 1);
|
||||
removed_nodes++;
|
||||
}
|
||||
}
|
||||
|
||||
Pin parsePin(const std::string &name)
|
||||
{
|
||||
Pin pin(name);
|
||||
@ -697,6 +586,9 @@ void TFImporter::populateNet(Net dstNet)
|
||||
RemoveIdentityOps(netBin);
|
||||
RemoveIdentityOps(netTxt);
|
||||
|
||||
if (!netTxt.ByteSize())
|
||||
simplifySubgraphs(netBin);
|
||||
|
||||
std::set<String> layers_to_ignore;
|
||||
|
||||
tensorflow::GraphDef& net = netTxt.ByteSize() != 0 ? netTxt : netBin;
|
||||
@ -936,10 +828,28 @@ void TFImporter::populateNet(Net dstNet)
|
||||
connect(layer_id, dstNet, inpId, id, 0);
|
||||
data_layouts[name] = DATA_LAYOUT_UNKNOWN;
|
||||
}
|
||||
else if (type == "Flatten")
|
||||
else if (type == "Flatten" || type == "Squeeze")
|
||||
{
|
||||
Pin inpId = parsePin(layer.input(0));
|
||||
if (data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
|
||||
int inpLayout = data_layouts[layer.input(0)];
|
||||
if (type == "Squeeze")
|
||||
{
|
||||
CV_Assert(hasLayerAttr(layer, "squeeze_dims"));
|
||||
const tensorflow::AttrValue& dims = getLayerAttr(layer, "squeeze_dims");
|
||||
if (inpLayout == DATA_LAYOUT_NHWC)
|
||||
{
|
||||
if (dims.list().i_size() != 2 || dims.list().i(0) != 1 || dims.list().i(1) != 2)
|
||||
CV_Error(Error::StsNotImplemented, "Unsupported squeeze configuration");
|
||||
}
|
||||
else if (inpLayout == DATA_LAYOUT_NCHW)
|
||||
{
|
||||
if (dims.list().i_size() != 2 || dims.list().i(0) != 2 || dims.list().i(1) != 3)
|
||||
CV_Error(Error::StsNotImplemented, "Unsupported squeeze configuration");
|
||||
}
|
||||
else
|
||||
CV_Error(Error::StsNotImplemented, "Unsupported squeeze configuration");
|
||||
}
|
||||
if (inpLayout == DATA_LAYOUT_NHWC)
|
||||
{
|
||||
LayerParams permLP;
|
||||
int order[] = {0, 2, 3, 1}; // From OpenCV's NCHW to NHWC.
|
||||
@ -1274,14 +1184,36 @@ void TFImporter::populateNet(Net dstNet)
|
||||
|
||||
bool isTraining = hasLayerAttr(layer, "is_training") && getLayerAttr(layer, "is_training").b();
|
||||
|
||||
layerParams.blobs.resize(4);
|
||||
Mat gamma, beta, mean, std;
|
||||
blobFromTensor(getConstBlob(layer, value_id, 1), gamma);
|
||||
blobFromTensor(getConstBlob(layer, value_id, 2), beta);
|
||||
layerParams.blobs.resize(2);
|
||||
|
||||
const tensorflow::TensorProto& gammaTensor = getConstBlob(layer, value_id, 1);
|
||||
if (!gammaTensor.tensor_content().empty())
|
||||
{
|
||||
layerParams.blobs.resize(layerParams.blobs.size() + 1);
|
||||
layerParams.set("has_weight", true);
|
||||
blobFromTensor(gammaTensor, layerParams.blobs.back());
|
||||
}
|
||||
else
|
||||
layerParams.set("has_weight", false);
|
||||
|
||||
const tensorflow::TensorProto& betaTensor = getConstBlob(layer, value_id, 2);
|
||||
if (!betaTensor.tensor_content().empty())
|
||||
{
|
||||
layerParams.blobs.resize(layerParams.blobs.size() + 1);
|
||||
layerParams.set("has_bias", true);
|
||||
blobFromTensor(betaTensor, layerParams.blobs.back());
|
||||
}
|
||||
else
|
||||
layerParams.set("has_bias", false);
|
||||
|
||||
Mat mean, std;
|
||||
if (isTraining)
|
||||
{
|
||||
mean = Mat::zeros(1, beta.total(), CV_32F);
|
||||
std = Mat::ones(1, beta.total(), CV_32F);
|
||||
if (layerParams.blobs.size() == 2)
|
||||
CV_Error(Error::StsNotImplemented, "Cannot determine number "
|
||||
"of parameters for batch normalization layer.");
|
||||
mean = Mat::zeros(1, layerParams.blobs[3].total(), CV_32F);
|
||||
std = Mat::ones(1, layerParams.blobs[3].total(), CV_32F);
|
||||
|
||||
// Add an extra layer: Mean-Variance normalization
|
||||
LayerParams mvnParams;
|
||||
@ -1299,15 +1231,10 @@ void TFImporter::populateNet(Net dstNet)
|
||||
}
|
||||
layerParams.blobs[0] = mean;
|
||||
layerParams.blobs[1] = std;
|
||||
layerParams.blobs[2] = gamma;
|
||||
layerParams.blobs[3] = beta;
|
||||
|
||||
if (hasLayerAttr(layer, "epsilon"))
|
||||
layerParams.set("eps", getLayerAttr(layer, "epsilon").f());
|
||||
|
||||
layerParams.set("has_weight", true);
|
||||
layerParams.set("has_bias", true);
|
||||
|
||||
int id = dstNet.addLayer(name, "BatchNorm", layerParams);
|
||||
layer_id[name] = id;
|
||||
|
||||
|
@ -150,6 +150,9 @@ TEST_P(Test_TensorFlow_layers, batch_norm)
|
||||
runTensorFlowNet("batch_norm_text", targetId, true);
|
||||
runTensorFlowNet("mvn_batch_norm", targetId);
|
||||
runTensorFlowNet("mvn_batch_norm_1x1", targetId);
|
||||
runTensorFlowNet("unfused_batch_norm", targetId);
|
||||
runTensorFlowNet("fused_batch_norm_no_gamma", targetId);
|
||||
runTensorFlowNet("unfused_batch_norm_no_gamma", targetId);
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, pooling)
|
||||
@ -185,6 +188,8 @@ TEST_P(Test_TensorFlow_layers, reshape)
|
||||
runTensorFlowNet("shift_reshape_no_reorder", targetId);
|
||||
runTensorFlowNet("reshape_reduce", targetId);
|
||||
runTensorFlowNet("flatten", targetId, true);
|
||||
runTensorFlowNet("unfused_flatten", targetId);
|
||||
runTensorFlowNet("unfused_flatten_unknown_batch", targetId);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(/**/, Test_TensorFlow_layers, availableDnnTargets());
|
||||
|
Loading…
Reference in New Issue
Block a user