diff --git a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp index cfb472ec00..80bb6b5367 100644 --- a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp +++ b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp @@ -419,6 +419,125 @@ public: } }; +class DeconvolutionValidKerasSubgraph : public Subgraph +{ +public: + DeconvolutionValidKerasSubgraph() + { + int input = addNodeToMatch(""); + int shape = addNodeToMatch("Shape", input); + int kernel = addNodeToMatch("Const"); + + int stack = addNodeToMatch("Const"); + int stack_1 = addNodeToMatch("Const"); + int stack_2 = addNodeToMatch("Const"); + int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2); + + stack = addNodeToMatch("Const"); + stack_1 = addNodeToMatch("Const"); + stack_2 = addNodeToMatch("Const"); + int strided_slice_1 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2); + + stack = addNodeToMatch("Const"); + stack_1 = addNodeToMatch("Const"); + stack_2 = addNodeToMatch("Const"); + int strided_slice_2 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2); + + int mul = addNodeToMatch("Mul", strided_slice_1, addNodeToMatch("Const")); + int add = addNodeToMatch("Add", mul, addNodeToMatch("Const")); + + int mul_1 = addNodeToMatch("Mul", strided_slice_2, addNodeToMatch("Const")); + int add_1 = addNodeToMatch("Add", mul_1, addNodeToMatch("Const")); + int pack = addNodeToMatch("Pack", strided_slice, add, add_1, addNodeToMatch("Const")); + addNodeToMatch("Conv2DBackpropInput", pack, kernel, input); + // Put any unused Const op to the first input. + setFusedNode("Conv2DBackpropInput", stack, kernel, input); + } + + virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode, + std::vector& inputNodes) CV_OVERRIDE + { + // Disable adjusted paddings (see Conv2DBackpropInput layer at tf_importer.cpp) + // adj_w = (outW - (pad == "SAME") ? 1 : kernelW) % strideX; + // adj_h = (outH - (pad == "SAME") ? 1 : kernelH) % strideY; + // Where outH and outW are 1st and 2nd dimensions (NHWC) or 2nd and third (NCHW). + std::string padMode = fusedNode->attr().at("padding").s(); + CV_Assert(padMode == "VALID"); + + const tensorflow::TensorShapeProto& kernelShape = + inputNodes[1]->mutable_attr()->at("value").tensor().tensor_shape(); + + CV_Assert(kernelShape.dim_size() == 4); + const int kernelHeight = kernelShape.dim(0).size(); + const int kernelWidth = kernelShape.dim(1).size(); + + tensorflow::TensorProto* outShape = inputNodes[0]->mutable_attr()->at("value").mutable_tensor(); + outShape->clear_int_val(); + outShape->add_int_val(-1); + outShape->add_int_val(kernelHeight); + outShape->add_int_val(kernelWidth); + outShape->add_int_val(-1); + } +}; + +class DeconvolutionSameKerasSubgraph : public Subgraph +{ +public: + DeconvolutionSameKerasSubgraph() + { + int input = addNodeToMatch(""); + int shape = addNodeToMatch("Shape", input); + int kernel = addNodeToMatch("Const"); + + int stack = addNodeToMatch("Const"); + int stack_1 = addNodeToMatch("Const"); + int stack_2 = addNodeToMatch("Const"); + int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2); + + stack = addNodeToMatch("Const"); + stack_1 = addNodeToMatch("Const"); + stack_2 = addNodeToMatch("Const"); + int strided_slice_1 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2); + + stack = addNodeToMatch("Const"); + stack_1 = addNodeToMatch("Const"); + stack_2 = addNodeToMatch("Const"); + int strided_slice_2 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2); + + int mul = addNodeToMatch("Mul", strided_slice_1, addNodeToMatch("Const")); + + int mul_1 = addNodeToMatch("Mul", strided_slice_2, addNodeToMatch("Const")); + int pack = addNodeToMatch("Pack", strided_slice, mul, mul_1, addNodeToMatch("Const")); + addNodeToMatch("Conv2DBackpropInput", pack, kernel, input); + // Put any unused Const op to the first input. + setFusedNode("Conv2DBackpropInput", stack, kernel, input); + } + + virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode, + std::vector& inputNodes) CV_OVERRIDE + { + // Disable adjusted paddings (see Conv2DBackpropInput layer at tf_importer.cpp) + // adj_w = (outW - (pad == "SAME") ? 1 : kernelW) % strideX; + // adj_h = (outH - (pad == "SAME") ? 1 : kernelH) % strideY; + // Where outH and outW are 1st and 2nd dimensions (NHWC) or 2nd and third (NCHW). + std::string padMode = fusedNode->attr().at("padding").s(); + CV_Assert(padMode == "SAME"); + + const tensorflow::AttrValue_ListValue& strides = fusedNode->attr().at("strides").list(); + CV_Assert(strides.i_size() == 4); + + const int strideY = strides.i(1); + const int strideX = strides.i(2); + + tensorflow::TensorProto* outShape = inputNodes[0]->mutable_attr()->at("value").mutable_tensor(); + outShape->clear_int_val(); + outShape->add_int_val(-1); + outShape->add_int_val(strideY); + outShape->add_int_val(strideX); + outShape->add_int_val(-1); + } +}; + void simplifySubgraphs(tensorflow::GraphDef& net) { std::vector > subgraphs; @@ -430,6 +549,8 @@ void simplifySubgraphs(tensorflow::GraphDef& net) subgraphs.push_back(Ptr(new ReLU6KerasSubgraph())); subgraphs.push_back(Ptr(new ReshapeKerasSubgraph(3))); subgraphs.push_back(Ptr(new L2NormalizeSubgraph())); + subgraphs.push_back(Ptr(new DeconvolutionValidKerasSubgraph())); + subgraphs.push_back(Ptr(new DeconvolutionSameKerasSubgraph())); int numNodes = net.node_size(); std::vector matchedNodesIds; diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index f5809163bd..a401f715be 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -1303,8 +1303,8 @@ void TFImporter::populateNet(Net dstNet) const int strideY = layerParams.get("stride_h"); const int strideX = layerParams.get("stride_w"); Mat outShape = getTensorContent(getConstBlob(layer, value_id, 0)); - const int outH = outShape.at(2); - const int outW = outShape.at(1); + const int outH = outShape.at(1); + const int outW = outShape.at(2); if (layerParams.get("pad_mode") == "SAME") { layerParams.set("adj_w", (outW - 1) % strideX); diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index 43e30ebe3d..b7f3c5c55a 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -173,6 +173,8 @@ TEST_P(Test_TensorFlow_layers, deconvolution) runTensorFlowNet("deconvolution_stride_2_same", targetId); runTensorFlowNet("deconvolution_adj_pad_valid", targetId); runTensorFlowNet("deconvolution_adj_pad_same", targetId); + runTensorFlowNet("keras_deconv_valid", targetId); + runTensorFlowNet("keras_deconv_same", targetId); } TEST_P(Test_TensorFlow_layers, matmul)