From 7e4b2057f27746863be51507b8187b491da42668 Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Wed, 25 Mar 2020 15:34:28 +0300 Subject: [PATCH] Import TF2.0 network from Keras --- .../dnn/src/tensorflow/tf_graph_simplifier.cpp | 15 ++++++++++++--- modules/dnn/test/test_tf_importer.cpp | 5 +++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp index 2e7bb574e9..b0978c2ace 100644 --- a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp +++ b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp @@ -682,6 +682,15 @@ void RemoveIdentityOps(tensorflow::GraphDef& net) IdentityOpsMap::iterator it = identity_ops.find(input_op_name); if (it != identity_ops.end()) { + // In case of Identity after Identity + while (true) + { + IdentityOpsMap::iterator nextIt = identity_ops.find(it->second); + if (nextIt != identity_ops.end()) + it = nextIt; + else + break; + } layer->set_input(input_id, it->second); } } @@ -847,7 +856,7 @@ void sortByExecutionOrder(tensorflow::GraphDef& net) nodesToAdd.push_back(i); else { - if (node.op() == "Merge" || node.op() == "RefMerge") + if (node.op() == "Merge" || node.op() == "RefMerge" || node.op() == "NoOp") { int numControlEdges = 0; for (int j = 0; j < numInputsInGraph; ++j) @@ -896,7 +905,7 @@ void removePhaseSwitches(tensorflow::GraphDef& net) { const tensorflow::NodeDef& node = net.node(i); nodesMap.insert(std::make_pair(node.name(), i)); - if (node.op() == "Switch" || node.op() == "Merge") + if (node.op() == "Switch" || node.op() == "Merge" || node.op() == "NoOp") { CV_Assert(node.input_size() > 0); // Replace consumers' inputs. @@ -914,7 +923,7 @@ void removePhaseSwitches(tensorflow::GraphDef& net) } } nodesToRemove.push_back(i); - if (node.op() == "Merge" || node.op() == "Switch") + if (node.op() == "Merge" || node.op() == "Switch" || node.op() == "NoOp") mergeOpSubgraphNodes.push(i); } } diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index 8cacae8ea8..0088cfdd92 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -867,6 +867,11 @@ TEST_P(Test_TensorFlow_layers, resize_bilinear) runTensorFlowNet("resize_bilinear_factor"); } +TEST_P(Test_TensorFlow_layers, tf2_keras) +{ + runTensorFlowNet("tf2_dense"); +} + TEST_P(Test_TensorFlow_layers, squeeze) { #if defined(INF_ENGINE_RELEASE)