diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index e546d9e1da..bdbe7d1a15 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -1129,15 +1129,14 @@ void TFImporter::populateNet(Net dstNet) if (value_id.find(layer.input(1)) != value_id.end()) { Mat newShape = getTensorContent(getConstBlob(layer, value_id, 1)); - + if (newShape.total() == 4) + { + // NHWC->NCHW + std::swap(*newShape.ptr(0, 2), *newShape.ptr(0, 3)); + std::swap(*newShape.ptr(0, 1), *newShape.ptr(0, 2)); + } if (inpLayout == DATA_LAYOUT_NHWC) { - if (newShape.total() == 4) - { - // NHWC->NCHW - std::swap(*newShape.ptr(0, 2), *newShape.ptr(0, 3)); - std::swap(*newShape.ptr(0, 1), *newShape.ptr(0, 2)); - } if (newShape.total() != 4 || newShape.at(1) == 1) { LayerParams permLP; diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index fa98e745f5..dac55d60b0 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -279,7 +279,7 @@ TEST_P(Test_TensorFlow_layers, matmul) // Reference output values are in range [-5.688, 4.484] double l1 = target == DNN_TARGET_MYRIAD ? 6.1e-3 : default_l1; runTensorFlowNet("nhwc_reshape_matmul", false, l1); - + runTensorFlowNet("matmul_layout"); } TEST_P(Test_TensorFlow_layers, reshape)