diff --git a/modules/dnn/src/tensorflow/tf_graph_editor.cpp b/modules/dnn/src/tensorflow/tf_graph_editor.cpp new file mode 100644 index 0000000000..6e841f2068 --- /dev/null +++ b/modules/dnn/src/tensorflow/tf_graph_editor.cpp @@ -0,0 +1,434 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +// Copyright (C) 2018, Intel Corporation, all rights reserved. +// Third party copyrights are property of their respective owners. + +#ifdef HAVE_PROTOBUF + +#include "tf_graph_editor.hpp" + +namespace cv { namespace dnn { +CV__DNN_EXPERIMENTAL_NS_BEGIN + +using ::google::protobuf::RepeatedField; +using ::google::protobuf::MapPair; + +class Subgraph // Interface to match and replace TensorFlow subgraphs. +{ +public: + // Add a node to be matched in the origin graph. Specify ids of nodes that + // are expected to be inputs. Returns id of a newly added node. + // TODO: Replace inputs to std::vector in C++11 + int addNodeToMatch(const std::string& op, int input_0 = -1, int input_1 = -1, + int input_2 = -1, int input_3 = -1) + { + int nodeInputs[] = {input_0, input_1, input_2, input_3}; + int numInputs = 0; + for (int i = 0; i < 4; ++i) + { + CV_Assert(nodeInputs[i] < (int)nodes.size()); + numInputs += (int)(nodeInputs[i] != -1); + } + nodes.push_back(op); + inputs.push_back(std::vector(&nodeInputs[0], &nodeInputs[0] + numInputs)); + return nodes.size() - 1; + } + + // Specify resulting node. All the matched nodes in subgraph excluding + // input nodes will be fused into this single node. + // TODO: Replace inputs to std::vector in C++11 + void setFusedNode(const std::string& op, int input_0 = -1, int input_1 = -1, + int input_2 = -1, int input_3 = -1, int input_4 = -1, + int input_5 = -1) + { + int nodeInputs[] = {input_0, input_1, input_2, input_3, input_4, input_5}; + int numInputs = 0; + for (int i = 0; i < 6; ++i) + { + CV_Assert(nodeInputs[i] < (int)nodes.size()); + numInputs += (int)(nodeInputs[i] != -1); + } + fusedNodeInputs = std::vector(&nodeInputs[0], &nodeInputs[0] + numInputs); + + fusedNodeOp = op; + nodesToFuse.clear(); + for (int i = 0; i < nodes.size(); ++i) + { + if (std::find(fusedNodeInputs.begin(), fusedNodeInputs.end(), i) == fusedNodeInputs.end()) + nodesToFuse.push_back(i); + } + } + + static const tensorflow::NodeDef& getInputNode(const tensorflow::GraphDef& net, + const tensorflow::NodeDef& node, + int inpId) + { + CV_Assert(inpId < node.input_size()); + std::string name = node.input(inpId); + const int numNodes = net.node_size(); + for (int i = 0; i < numNodes; ++i) + { + const tensorflow::NodeDef& node = net.node(i); + if (node.name() == name) + return node; + } + CV_Error(Error::StsParseError, "Input node with name " + name + " not found"); + return net.node(0); // just return something + } + + // Match TensorFlow subgraph starting from with a set of nodes to be fused. + // Returns true if nodes are matched and can be fused. + bool match(const tensorflow::GraphDef& net, int nodeId, int* numMatchedNodes) + { + *numMatchedNodes = 0; + int numNodes = net.node_size(); + for (int i = 0; i < nodesToFuse.size(); ++i) + { + if (nodeId + i > numNodes - 1) + return false; + + const tensorflow::NodeDef &node = net.node(nodeId + i); + if (node.op() != nodes[nodesToFuse[i]]) + return false; + + std::vector& inputNodes = inputs[nodesToFuse[i]]; + if (inputNodes.size() != node.input_size()) + return false; + for (int j = 0; j < inputNodes.size(); ++j) + { + if (nodes[inputNodes[j]].empty()) // Unknown input node type. + continue; + const tensorflow::NodeDef& inpNode = getInputNode(net, node, j); + if (inpNode.op() != nodes[inputNodes[j]]) + return false; + } + + *numMatchedNodes += 1; + } + return true; + } + + // Fuse matched subgraph. + void replace(tensorflow::GraphDef& net, int nodeId, int* numReplacedNodes) + { + *numReplacedNodes = 0; + + // Extract names of input nodes. + std::vector inputsNames(fusedNodeInputs.size()); + for (int i = 0; i < fusedNodeInputs.size(); ++i) + { + std::string inpName; + // Find input node name looking at inputs of fused nodes. + for (int j = 0; j < nodesToFuse.size() && inpName.empty(); ++j) + { + const tensorflow::NodeDef &node = net.node(nodeId + j); + std::vector& inpIndices = inputs[nodesToFuse[j]]; + + CV_Assert(node.input_size() == inpIndices.size()); + for (int k = 0; k < inpIndices.size(); ++k) + { + if (inpIndices[k] == fusedNodeInputs[i]) + { + inpName = node.input(k); + break; + } + } + } + CV_Assert(!inpName.empty()); + inputsNames[i] = inpName; + } + + // Remove all nodes except the last one. + *numReplacedNodes = nodesToFuse.size() - 1; + net.mutable_node()->DeleteSubrange(nodeId, *numReplacedNodes); + + // Modify the last node to be a fused one. + tensorflow::NodeDef* node = net.mutable_node(nodeId); + node->set_op(fusedNodeOp); + node->clear_input(); + for (int i = 0; i < inputsNames.size(); ++i) + { + node->add_input(inputsNames[i]); + } + + std::vector inputNodes(inputsNames.size()); + for (int i = 0; i < inputsNames.size(); ++i) + { + inputNodes[i] = getInputNode(net, *node, i); + } + finalize(net, node, inputNodes); + } + + virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef*, + const std::vector&) {} + +private: + std::vector nodes; // Nodes to be matched in the origin graph. + std::vector > inputs; // Connections of an every node to it's inputs. + + std::string fusedNodeOp; // Operation name of resulting fused node. + std::vector nodesToFuse; // Set of nodes to be fused. + std::vector fusedNodeInputs; // Inputs of fused node. +}; + +class BatchNormSubgraph : public Subgraph +{ +public: + BatchNormSubgraph() + { + int input = addNodeToMatch(""); + int epsilon = addNodeToMatch("Const"); + int moving_variance = addNodeToMatch("Const"); + int moving_mean = addNodeToMatch("Const"); + int beta = addNodeToMatch("Const"); + int gamma = addNodeToMatch("Const"); + int add = addNodeToMatch("Add", moving_variance, epsilon); + int rsqrt = addNodeToMatch("Rsqrt", add); + int mul = addNodeToMatch("Mul", rsqrt, gamma); + int mul_1 = addNodeToMatch("Mul", input, mul); + int mul_2 = addNodeToMatch("Mul", moving_mean, mul); + int sub = addNodeToMatch("Sub", beta, mul_2); + addNodeToMatch("Add", mul_1, sub); + + setFusedNode("FusedBatchNorm", input, gamma, beta, moving_mean, moving_variance, epsilon); + } + + virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode, + const std::vector& inputNodes) + { + Mat epsMat = getTensorContent(inputNodes.back().attr().at("value").tensor()); + CV_Assert(epsMat.total() == 1, epsMat.type() == CV_32FC1); + + fusedNode->mutable_input()->ReleaseLast(); + fusedNode->clear_attr(); + tensorflow::AttrValue epsilon; + epsilon.set_f(epsMat.at(0)); + fusedNode->mutable_attr()->insert(MapPair("epsilon", epsilon)); + } +}; + +class BatchNormNoGammaSubgraph : public Subgraph +{ +public: + BatchNormNoGammaSubgraph() + { + int input = addNodeToMatch(""); + int epsilon = addNodeToMatch("Const"); + int moving_variance = addNodeToMatch("Const"); + int moving_mean = addNodeToMatch("Const"); + int beta = addNodeToMatch("Const"); + int add = addNodeToMatch("Add", moving_variance, epsilon); + int rsqrt = addNodeToMatch("Rsqrt", add); + int mul = addNodeToMatch("Mul", input, rsqrt); + int mul_1 = addNodeToMatch("Mul", moving_mean, rsqrt); + int sub = addNodeToMatch("Sub", beta, mul_1); + addNodeToMatch("Add", mul, sub); + + // There is a fake reference to beta that will be replaced to a new gamma tensor. + setFusedNode("FusedBatchNorm", input, beta, beta, moving_mean, moving_variance, epsilon); + } + + virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode, + const std::vector& inputNodes) + { + Mat epsMat = getTensorContent(inputNodes.back().attr().at("value").tensor()); + CV_Assert(epsMat.total() == 1, epsMat.type() == CV_32FC1); + + fusedNode->mutable_input()->ReleaseLast(); + fusedNode->clear_attr(); + tensorflow::AttrValue epsilon; + epsilon.set_f(epsMat.at(0)); + fusedNode->mutable_attr()->insert(MapPair("epsilon", epsilon)); + + tensorflow::NodeDef* gamma = net.add_node(); + gamma->set_op("Const"); + gamma->set_name(fusedNode->name() + "/gamma"); + // Just put a single value to recognize this node as Const. + gamma->mutable_attr()->insert(MapPair("value", epsilon)); + fusedNode->set_input(1, gamma->name()); + } +}; + +// tf.contrib.layers.flatten +class FlattenSubgraph : public Subgraph +{ +public: + FlattenSubgraph() + { + int input = addNodeToMatch(""); + int shape = addNodeToMatch("Const"); + int stack = addNodeToMatch("Const"); + int stack_1 = addNodeToMatch("Const"); + int stack_2 = addNodeToMatch("Const"); + int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2); + int shape_pack = addNodeToMatch("Const"); + int pack = addNodeToMatch("Pack", strided_slice, shape_pack); + addNodeToMatch("Reshape", input, pack); + + setFusedNode("Flatten", input); + } +}; + +// tf.contrib.layers.flatten in case of unknown batch size +class FlattenShapeSubgraph : public Subgraph +{ +public: + FlattenShapeSubgraph() + { + int input = addNodeToMatch(""); + int shape = addNodeToMatch("Shape", input); + int stack = addNodeToMatch("Const"); + int stack_1 = addNodeToMatch("Const"); + int stack_2 = addNodeToMatch("Const"); + int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2); + int shape_pack = addNodeToMatch("Const"); + int pack = addNodeToMatch("Pack", strided_slice, shape_pack); + addNodeToMatch("Reshape", input, pack); + + setFusedNode("Flatten", input); + } +}; + +void simplifySubgraphs(tensorflow::GraphDef& net) +{ + std::vector > subgraphs; + subgraphs.push_back(Ptr(new BatchNormSubgraph())); + subgraphs.push_back(Ptr(new BatchNormNoGammaSubgraph())); + subgraphs.push_back(Ptr(new FlattenSubgraph())); + subgraphs.push_back(Ptr(new FlattenShapeSubgraph())); + + int numNodes = net.node_size(); + int numMatchedNodes, numReplacedNodes; + for (int i = 0; i < numNodes; ++i) + { + for (int j = 0; j < subgraphs.size(); ++j) + { + if (subgraphs[j]->match(net, i, &numMatchedNodes)) + { + subgraphs[j]->replace(net, i, &numReplacedNodes); + numNodes -= numReplacedNodes; + break; + } + } + } +} + +void RemoveIdentityOps(tensorflow::GraphDef& net) +{ + typedef std::map IdentityOpsMap; + IdentityOpsMap identity_ops; + + std::vector identity_ops_idx; + + int layersCount = net.node_size(); + for (int li = 0; li < layersCount; li++) + { + const tensorflow::NodeDef &layer = net.node(li); + String type = layer.op(); + + if (type == "Identity" || type == "Dropout") { + identity_ops_idx.push_back(li); + identity_ops[layer.name()] = layer.input(0); + } + } + + for (int li = 0; li < layersCount; li++) + { + tensorflow::NodeDef* layer = net.mutable_node(li); + for (int input_id = 0; input_id < layer->input_size(); input_id++) { + String input_op_name = layer->input(input_id); + IdentityOpsMap::iterator it = identity_ops.find(input_op_name); + + if (it != identity_ops.end()) { + layer->set_input(input_id, it->second); + } + } + } + + std::sort(identity_ops_idx.begin(), identity_ops_idx.end()); + + int removed_nodes = 0; + for(size_t i = 0; i < identity_ops_idx.size(); i++) { + int start_id = identity_ops_idx[i] - removed_nodes; + net.mutable_node()->DeleteSubrange(start_id, 1); + removed_nodes++; + } +} + +Mat getTensorContent(const tensorflow::TensorProto &tensor) +{ + std::string content = tensor.tensor_content(); + switch (tensor.dtype()) + { + case tensorflow::DT_FLOAT: + { + if (!content.empty()) + return Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()).clone(); + else + { + const RepeatedField& field = tensor.float_val(); + CV_Assert(!field.empty()); + return Mat(1, field.size(), CV_32FC1, (void*)field.data()).clone(); + } + } + case tensorflow::DT_DOUBLE: + { + if (!content.empty()) + return Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()).clone(); + else + { + const RepeatedField& field = tensor.double_val(); + CV_Assert(!field.empty()); + return Mat(1, field.size(), CV_64FC1, (void*)field.data()).clone(); + } + } + case tensorflow::DT_INT32: + { + if (!content.empty()) + return Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()).clone(); + else + { + const RepeatedField& field = tensor.int_val(); + CV_Assert(!field.empty()); + return Mat(1, field.size(), CV_32SC1, (void*)field.data()).clone(); + } + } + case tensorflow::DT_HALF: + { + Mat halfs; + if (!content.empty()) + { + static const int kHalfSize = 2; + halfs = Mat(1, content.size() / kHalfSize, CV_16UC1, (void*)content.c_str()); + } + else + { + const RepeatedField& field = tensor.half_val(); + CV_Assert(!field.empty()); + Mat ints(1, field.size(), CV_32SC1, (void*)field.data()); + ints.convertTo(halfs, CV_16UC1); + } + // Reinterpret as a signed shorts just for a convertFp16 call. + Mat halfsSigned(halfs.size(), CV_16SC1, halfs.data); + Mat floats(halfs.size(), CV_32FC1); + convertFp16(halfsSigned, floats); + return floats; + } + case tensorflow::DT_QUINT8: + { + CV_Assert(!content.empty()); + return Mat(1, content.size(), CV_8UC1, (void*)content.c_str()).clone(); + } + default: + CV_Error(Error::StsError, "Tensor's data type is not supported"); + break; + } + return Mat(); +} + +CV__DNN_EXPERIMENTAL_NS_END +}} // namespace dnn, namespace cv + +#endif // HAVE_PROTOBUF diff --git a/modules/dnn/src/tensorflow/tf_graph_editor.hpp b/modules/dnn/src/tensorflow/tf_graph_editor.hpp new file mode 100644 index 0000000000..5568c09b5e --- /dev/null +++ b/modules/dnn/src/tensorflow/tf_graph_editor.hpp @@ -0,0 +1,30 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +// Copyright (C) 2018, Intel Corporation, all rights reserved. +// Third party copyrights are property of their respective owners. + +#ifndef __OPENCV_DNN_TF_SIMPLIFIER_HPP__ +#define __OPENCV_DNN_TF_SIMPLIFIER_HPP__ + +#include "../precomp.hpp" + +#ifdef HAVE_PROTOBUF + +#include "tf_io.hpp" + +namespace cv { namespace dnn { +CV__DNN_EXPERIMENTAL_NS_BEGIN + +void RemoveIdentityOps(tensorflow::GraphDef& net); + +void simplifySubgraphs(tensorflow::GraphDef& net); + +Mat getTensorContent(const tensorflow::TensorProto &tensor); + +CV__DNN_EXPERIMENTAL_NS_END +}} // namespace dnn, namespace cv + +#endif // HAVE_PROTOBUF +#endif // __OPENCV_DNN_TF_SIMPLIFIER_HPP__ diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index 5309ec40ce..9be29b9c41 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -22,6 +22,7 @@ Implementation of Tensorflow models parser #include #include #include "tf_io.hpp" +#include "tf_graph_editor.hpp" #endif namespace cv { @@ -87,77 +88,6 @@ void blobShapeFromTensor(const tensorflow::TensorProto &tensor, MatShape& shape) } } -static Mat getTensorContent(const tensorflow::TensorProto &tensor) -{ - std::string content = tensor.tensor_content(); - switch (tensor.dtype()) - { - case tensorflow::DT_FLOAT: - { - if (!content.empty()) - return Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()).clone(); - else - { - const RepeatedField& field = tensor.float_val(); - CV_Assert(!field.empty()); - return Mat(1, field.size(), CV_32FC1, (void*)field.data()).clone(); - } - } - case tensorflow::DT_DOUBLE: - { - if (!content.empty()) - return Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()).clone(); - else - { - const RepeatedField& field = tensor.double_val(); - CV_Assert(!field.empty()); - return Mat(1, field.size(), CV_64FC1, (void*)field.data()).clone(); - } - } - case tensorflow::DT_INT32: - { - if (!content.empty()) - return Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()).clone(); - else - { - const RepeatedField& field = tensor.int_val(); - CV_Assert(!field.empty()); - return Mat(1, field.size(), CV_32SC1, (void*)field.data()).clone(); - } - } - case tensorflow::DT_HALF: - { - Mat halfs; - if (!content.empty()) - { - static const int kHalfSize = 2; - halfs = Mat(1, content.size() / kHalfSize, CV_16UC1, (void*)content.c_str()); - } - else - { - const RepeatedField& field = tensor.half_val(); - CV_Assert(!field.empty()); - Mat ints(1, field.size(), CV_32SC1, (void*)field.data()); - ints.convertTo(halfs, CV_16UC1); - } - // Reinterpret as a signed shorts just for a convertFp16 call. - Mat halfsSigned(halfs.size(), CV_16SC1, halfs.data); - Mat floats(halfs.size(), CV_32FC1); - convertFp16(halfsSigned, floats); - return floats; - } - case tensorflow::DT_QUINT8: - { - CV_Assert(!content.empty()); - return Mat(1, content.size(), CV_8UC1, (void*)content.c_str()).clone(); - } - default: - CV_Error(Error::StsError, "Tensor's data type is not supported"); - break; - } - return Mat(); -} - template void parseTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob) { @@ -364,47 +294,6 @@ void setPadding(LayerParams &layerParams, const tensorflow::NodeDef &layer) layerParams.set("pad_mode", getLayerAttr(layer, "padding").s()); } -void RemoveIdentityOps(tensorflow::GraphDef& net) { - typedef std::map IdentityOpsMap; - IdentityOpsMap identity_ops; - - std::vector identity_ops_idx; - - int layersCount = net.node_size(); - for (int li = 0; li < layersCount; li++) - { - const tensorflow::NodeDef &layer = net.node(li); - String type = layer.op(); - - if (type == "Identity" || type == "Dropout") { - identity_ops_idx.push_back(li); - identity_ops[layer.name()] = layer.input(0); - } - } - - for (int li = 0; li < layersCount; li++) - { - tensorflow::NodeDef* layer = net.mutable_node(li); - for (int input_id = 0; input_id < layer->input_size(); input_id++) { - String input_op_name = layer->input(input_id); - IdentityOpsMap::iterator it = identity_ops.find(input_op_name); - - if (it != identity_ops.end()) { - layer->set_input(input_id, it->second); - } - } - } - - std::sort(identity_ops_idx.begin(), identity_ops_idx.end()); - - int removed_nodes = 0; - for(size_t i = 0; i < identity_ops_idx.size(); i++) { - int start_id = identity_ops_idx[i] - removed_nodes; - net.mutable_node()->DeleteSubrange(start_id, 1); - removed_nodes++; - } -} - Pin parsePin(const std::string &name) { Pin pin(name); @@ -697,6 +586,9 @@ void TFImporter::populateNet(Net dstNet) RemoveIdentityOps(netBin); RemoveIdentityOps(netTxt); + if (!netTxt.ByteSize()) + simplifySubgraphs(netBin); + std::set layers_to_ignore; tensorflow::GraphDef& net = netTxt.ByteSize() != 0 ? netTxt : netBin; @@ -936,10 +828,28 @@ void TFImporter::populateNet(Net dstNet) connect(layer_id, dstNet, inpId, id, 0); data_layouts[name] = DATA_LAYOUT_UNKNOWN; } - else if (type == "Flatten") + else if (type == "Flatten" || type == "Squeeze") { Pin inpId = parsePin(layer.input(0)); - if (data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC) + int inpLayout = data_layouts[layer.input(0)]; + if (type == "Squeeze") + { + CV_Assert(hasLayerAttr(layer, "squeeze_dims")); + const tensorflow::AttrValue& dims = getLayerAttr(layer, "squeeze_dims"); + if (inpLayout == DATA_LAYOUT_NHWC) + { + if (dims.list().i_size() != 2 || dims.list().i(0) != 1 || dims.list().i(1) != 2) + CV_Error(Error::StsNotImplemented, "Unsupported squeeze configuration"); + } + else if (inpLayout == DATA_LAYOUT_NCHW) + { + if (dims.list().i_size() != 2 || dims.list().i(0) != 2 || dims.list().i(1) != 3) + CV_Error(Error::StsNotImplemented, "Unsupported squeeze configuration"); + } + else + CV_Error(Error::StsNotImplemented, "Unsupported squeeze configuration"); + } + if (inpLayout == DATA_LAYOUT_NHWC) { LayerParams permLP; int order[] = {0, 2, 3, 1}; // From OpenCV's NCHW to NHWC. @@ -1274,14 +1184,36 @@ void TFImporter::populateNet(Net dstNet) bool isTraining = hasLayerAttr(layer, "is_training") && getLayerAttr(layer, "is_training").b(); - layerParams.blobs.resize(4); - Mat gamma, beta, mean, std; - blobFromTensor(getConstBlob(layer, value_id, 1), gamma); - blobFromTensor(getConstBlob(layer, value_id, 2), beta); + layerParams.blobs.resize(2); + + const tensorflow::TensorProto& gammaTensor = getConstBlob(layer, value_id, 1); + if (!gammaTensor.tensor_content().empty()) + { + layerParams.blobs.resize(layerParams.blobs.size() + 1); + layerParams.set("has_weight", true); + blobFromTensor(gammaTensor, layerParams.blobs.back()); + } + else + layerParams.set("has_weight", false); + + const tensorflow::TensorProto& betaTensor = getConstBlob(layer, value_id, 2); + if (!betaTensor.tensor_content().empty()) + { + layerParams.blobs.resize(layerParams.blobs.size() + 1); + layerParams.set("has_bias", true); + blobFromTensor(betaTensor, layerParams.blobs.back()); + } + else + layerParams.set("has_bias", false); + + Mat mean, std; if (isTraining) { - mean = Mat::zeros(1, beta.total(), CV_32F); - std = Mat::ones(1, beta.total(), CV_32F); + if (layerParams.blobs.size() == 2) + CV_Error(Error::StsNotImplemented, "Cannot determine number " + "of parameters for batch normalization layer."); + mean = Mat::zeros(1, layerParams.blobs[3].total(), CV_32F); + std = Mat::ones(1, layerParams.blobs[3].total(), CV_32F); // Add an extra layer: Mean-Variance normalization LayerParams mvnParams; @@ -1299,15 +1231,10 @@ void TFImporter::populateNet(Net dstNet) } layerParams.blobs[0] = mean; layerParams.blobs[1] = std; - layerParams.blobs[2] = gamma; - layerParams.blobs[3] = beta; if (hasLayerAttr(layer, "epsilon")) layerParams.set("eps", getLayerAttr(layer, "epsilon").f()); - layerParams.set("has_weight", true); - layerParams.set("has_bias", true); - int id = dstNet.addLayer(name, "BatchNorm", layerParams); layer_id[name] = id; diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index b3b995948d..b5c95673ca 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -150,6 +150,9 @@ TEST_P(Test_TensorFlow_layers, batch_norm) runTensorFlowNet("batch_norm_text", targetId, true); runTensorFlowNet("mvn_batch_norm", targetId); runTensorFlowNet("mvn_batch_norm_1x1", targetId); + runTensorFlowNet("unfused_batch_norm", targetId); + runTensorFlowNet("fused_batch_norm_no_gamma", targetId); + runTensorFlowNet("unfused_batch_norm_no_gamma", targetId); } TEST_P(Test_TensorFlow_layers, pooling) @@ -185,6 +188,8 @@ TEST_P(Test_TensorFlow_layers, reshape) runTensorFlowNet("shift_reshape_no_reorder", targetId); runTensorFlowNet("reshape_reduce", targetId); runTensorFlowNet("flatten", targetId, true); + runTensorFlowNet("unfused_flatten", targetId); + runTensorFlowNet("unfused_flatten_unknown_batch", targetId); } INSTANTIATE_TEST_CASE_P(/**/, Test_TensorFlow_layers, availableDnnTargets());