diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index 986225a8c6..6a4a0ab225 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -246,16 +246,41 @@ const tensorflow::AttrValue& getLayerAttr(const tensorflow::NodeDef &layer, cons return layer.attr().at(name); } +static int getDataLayout(const tensorflow::NodeDef& layer) +{ + if (hasLayerAttr(layer, "data_format")) + { + std::string format = getLayerAttr(layer, "data_format").s(); + if (format == "NHWC" || format == "channels_last") + return DATA_LAYOUT_NHWC; + else if (format == "NCHW" || format == "channels_first") + return DATA_LAYOUT_NCHW; + else + CV_Error(Error::StsParseError, "Unknown data_format value: " + format); + } + return DATA_LAYOUT_UNKNOWN; +} + void setStrides(LayerParams &layerParams, const tensorflow::NodeDef &layer) { if (hasLayerAttr(layer, "strides")) { const tensorflow::AttrValue& val = getLayerAttr(layer, "strides"); + int dimX, dimY, dimC; + int layout = getDataLayout(layer); + if (layout == DATA_LAYOUT_NCHW) + { + dimC = 1; dimY = 2; dimX = 3; + } + else + { + dimY = 1; dimX = 2; dimC = 3; + } if (val.list().i_size() != 4 || - val.list().i(0) != 1 || val.list().i(3) != 1) + val.list().i(0) != 1 || val.list().i(dimC) != 1) CV_Error(Error::StsError, "Unsupported strides"); - layerParams.set("stride_h", static_cast(val.list().i(1))); - layerParams.set("stride_w", static_cast(val.list().i(2))); + layerParams.set("stride_h", static_cast(val.list().i(dimY))); + layerParams.set("stride_w", static_cast(val.list().i(dimX))); } } @@ -278,11 +303,21 @@ void setKSize(LayerParams &layerParams, const tensorflow::NodeDef &layer) if (hasLayerAttr(layer, "ksize")) { const tensorflow::AttrValue& val = getLayerAttr(layer, "ksize"); + int dimX, dimY, dimC; + int layout = getDataLayout(layer); + if (layout == DATA_LAYOUT_NCHW) + { + dimC = 1; dimY = 2; dimX = 3; + } + else + { + dimY = 1; dimX = 2; dimC = 3; + } if (val.list().i_size() != 4 || - val.list().i(0) != 1 || val.list().i(3) != 1) + val.list().i(0) != 1 || val.list().i(dimC) != 1) CV_Error(Error::StsError, "Unsupported ksize"); - layerParams.set("kernel_h", static_cast(val.list().i(1))); - layerParams.set("kernel_w", static_cast(val.list().i(2))); + layerParams.set("kernel_h", static_cast(val.list().i(dimY))); + layerParams.set("kernel_w", static_cast(val.list().i(dimX))); } else { @@ -568,21 +603,6 @@ static void addConstNodes(tensorflow::GraphDef& net, std::map& cons } } -static int getDataLayout(const tensorflow::NodeDef& layer) -{ - if (hasLayerAttr(layer, "data_format")) - { - std::string format = getLayerAttr(layer, "data_format").s(); - if (format == "NHWC" || format == "channels_last") - return DATA_LAYOUT_NHWC; - else if (format == "NCHW" || format == "channels_first") - return DATA_LAYOUT_NCHW; - else - CV_Error(Error::StsParseError, "Unknown data_format value: " + format); - } - return DATA_LAYOUT_UNKNOWN; -} - static inline std::string getNodeName(const std::string& tensorName) { return tensorName.substr(0, tensorName.rfind(':')); diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index 747fefd913..d4ffc94399 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -127,6 +127,7 @@ TEST_P(Test_TensorFlow_layers, conv) runTensorFlowNet("atrous_conv2d_same", targetId); runTensorFlowNet("depthwise_conv2d", targetId); runTensorFlowNet("keras_atrous_conv2d_same", targetId); + runTensorFlowNet("conv_pool_nchw", targetId); } TEST_P(Test_TensorFlow_layers, padding)