From b4a6aa335d8c793397529570dc1c5b5c7be578ab Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Wed, 13 May 2020 23:51:52 +0300 Subject: [PATCH] TensorFlow bilinear resize downscale --- modules/dnn/src/layers/resize_layer.cpp | 6 +-- .../src/tensorflow/tf_graph_simplifier.cpp | 51 ++++++++++++++++++- modules/dnn/src/tensorflow/tf_importer.cpp | 8 +-- modules/dnn/test/test_tf_importer.cpp | 1 + 4 files changed, 57 insertions(+), 9 deletions(-) diff --git a/modules/dnn/src/layers/resize_layer.cpp b/modules/dnn/src/layers/resize_layer.cpp index c86fa7f717..09e68eee47 100644 --- a/modules/dnn/src/layers/resize_layer.cpp +++ b/modules/dnn/src/layers/resize_layer.cpp @@ -19,8 +19,8 @@ namespace cv { namespace dnn { class ResizeLayerImpl : public ResizeLayer { public: - ResizeLayerImpl(const LayerParams& params) : zoomFactorWidth(params.get("zoom_factor_x", params.get("zoom_factor", 0))), - zoomFactorHeight(params.get("zoom_factor_y", params.get("zoom_factor", 0))), + ResizeLayerImpl(const LayerParams& params) : zoomFactorWidth(params.get("zoom_factor_x", params.get("zoom_factor", 0))), + zoomFactorHeight(params.get("zoom_factor_y", params.get("zoom_factor", 0))), scaleWidth(0), scaleHeight(0) { setParamsFrom(params); @@ -223,7 +223,7 @@ public: protected: int outWidth, outHeight; - const int zoomFactorWidth, zoomFactorHeight; + const float zoomFactorWidth, zoomFactorHeight; String interpolation; float scaleWidth, scaleHeight; bool alignCorners; diff --git a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp index 1afed2cf46..99b3d7ac2f 100644 --- a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp +++ b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp @@ -495,8 +495,9 @@ public: ResizeBilinearSubgraph() { int input = addNodeToMatch(""); + int shapeSource = addNodeToMatch(""); - int shape = addNodeToMatch("Shape", input); + int shape = addNodeToMatch("Shape", shapeSource); int stack = addNodeToMatch("Const"); int stack_1 = addNodeToMatch("Const"); int stack_2 = addNodeToMatch("Const"); @@ -504,7 +505,7 @@ public: int factorY = addNodeToMatch("Const"); int mul = addNodeToMatch("Mul", strided_slice, factorY); - shape = addNodeToMatch("Shape", input); + shape = addNodeToMatch("Shape", shapeSource); stack = addNodeToMatch("Const"); stack_1 = addNodeToMatch("Const"); stack_2 = addNodeToMatch("Const"); @@ -519,6 +520,51 @@ public: } }; +// In case of resizing by factor. +class ResizeBilinearSubgraphDown : public TFSubgraph +{ +public: + ResizeBilinearSubgraphDown() + { + int input = addNodeToMatch(""); + int shapeSource = addNodeToMatch(""); + + int shape = addNodeToMatch("Shape", shapeSource); + int stack = addNodeToMatch("Const"); + int stack_1 = addNodeToMatch("Const"); + int stack_2 = addNodeToMatch("Const"); + int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2); + int factorY = addNodeToMatch("Const"); + int div = addNodeToMatch("RealDiv", addNodeToMatch("Cast", strided_slice), factorY); + int cast = addNodeToMatch("Cast", div); + + shape = addNodeToMatch("Shape", shapeSource); + stack = addNodeToMatch("Const"); + stack_1 = addNodeToMatch("Const"); + stack_2 = addNodeToMatch("Const"); + strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2); + int factorX = addNodeToMatch("Const"); + int div_1 = addNodeToMatch("RealDiv", addNodeToMatch("Cast", strided_slice), factorX); + int cast_1 = addNodeToMatch("Cast", div_1); + + int pack = addNodeToMatch("Pack", cast, cast_1); + + addNodeToMatch("ResizeBilinear", input, pack); + setFusedNode("ResizeBilinear", input, factorY, factorX); + } + + virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode, + std::vector& inputNodes) CV_OVERRIDE + { + + for (int i = 1; i < 3; ++i) + { + tensorflow::TensorProto* factor = inputNodes[i]->mutable_attr()->at("value").mutable_tensor(); + factor->set_double_val(0, 1.0 / factor->double_val(0)); + } + } +}; + // In case of resizing by factor. class UpsamplingKerasSubgraph : public TFSubgraph { @@ -702,6 +748,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net) subgraphs.push_back(Ptr(new PReLUSubgraph(true))); subgraphs.push_back(Ptr(new PReLUSubgraph(false))); subgraphs.push_back(Ptr(new FlattenProdSubgraph())); + subgraphs.push_back(Ptr(new ResizeBilinearSubgraphDown())); for (int i = 0; i < net.node_size(); ++i) { diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index 9fd611fd0a..e684b94e46 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -1932,10 +1932,10 @@ void TFImporter::populateNet(Net dstNet) { Mat factorHeight = getTensorContent(getConstBlob(layer, value_id, 1)); Mat factorWidth = getTensorContent(getConstBlob(layer, value_id, 2)); - CV_CheckTypeEQ(factorHeight.type(), CV_32SC1, ""); CV_CheckEQ(factorHeight.total(), (size_t)1, ""); - CV_CheckTypeEQ(factorWidth.type(), CV_32SC1, ""); CV_CheckEQ(factorWidth.total(), (size_t)1, ""); - layerParams.set("zoom_factor_x", factorWidth.at(0)); - layerParams.set("zoom_factor_y", factorHeight.at(0)); + factorHeight.convertTo(factorHeight, CV_32F); + factorWidth.convertTo(factorWidth, CV_32F); + layerParams.set("zoom_factor_x", factorWidth.at(0)); + layerParams.set("zoom_factor_y", factorHeight.at(0)); } else CV_Assert(layer.input_size() == 2 || layer.input_size() == 3); diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index 642b5158b1..b20b2a58ff 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -969,6 +969,7 @@ TEST_P(Test_TensorFlow_layers, resize_bilinear) { runTensorFlowNet("resize_bilinear"); runTensorFlowNet("resize_bilinear_factor"); + runTensorFlowNet("resize_bilinear_down"); } TEST_P(Test_TensorFlow_layers, tf2_dense)