diff --git a/modules/dnn/src/layers/reshape_layer.cpp b/modules/dnn/src/layers/reshape_layer.cpp index 65a81c7820..c9e632dd29 100644 --- a/modules/dnn/src/layers/reshape_layer.cpp +++ b/modules/dnn/src/layers/reshape_layer.cpp @@ -82,17 +82,26 @@ static void computeShapeByReshapeMask(const MatShape &srcShape, { if (matched) { - if (i == 0 || total(srcShape, i, srcRange.end) != maskTotal) + if (total(srcShape, i, srcRange.end) != maskTotal) { srcRange.start = i + 1; break; } + else if (i == 0) + { + srcRange.start = 0; + break; + } } else { matched = total(srcShape, i, srcRange.end) == maskTotal; } } + while (total(srcShape, srcRange.start, srcRange.end) != maskTotal && srcRange.start > 0) + { + srcRange.start -= 1; + } CV_Assert(total(srcShape, srcRange.start, srcRange.end) == maskTotal); } diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index 1faa7fba4d..7d7d300386 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -262,6 +262,18 @@ static int getDataLayout(const tensorflow::NodeDef& layer) return DATA_LAYOUT_UNKNOWN; } +static inline std::string getNodeName(const std::string& tensorName) +{ + return tensorName.substr(0, tensorName.rfind(':')); +} + +static inline int getDataLayout(const std::string& layerName, + const std::map& data_layouts) +{ + std::map::const_iterator it = data_layouts.find(getNodeName(layerName)); + return it != data_layouts.end() ? it->second : DATA_LAYOUT_UNKNOWN; +} + void setStrides(LayerParams &layerParams, const tensorflow::NodeDef &layer) { if (hasLayerAttr(layer, "strides")) @@ -604,11 +616,6 @@ static void addConstNodes(tensorflow::GraphDef& net, std::map& cons } } -static inline std::string getNodeName(const std::string& tensorName) -{ - return tensorName.substr(0, tensorName.rfind(':')); -} - // If all inputs of specific layer have the same data layout we can say that // this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise. static int predictOutputDataLayout(const tensorflow::GraphDef& net, @@ -830,7 +837,8 @@ void TFImporter::populateNet(Net dstNet) // one input only connect(layer_id, dstNet, parsePin(input), id, 0); - if (data_layouts[name] == DATA_LAYOUT_UNKNOWN) + + if (getDataLayout(name, data_layouts) == DATA_LAYOUT_UNKNOWN) data_layouts[name] = DATA_LAYOUT_NHWC; } else if (type == "BiasAdd" || type == "Add") @@ -956,7 +964,8 @@ void TFImporter::populateNet(Net dstNet) Pin inpId = parsePin(layer.input(0)); Mat newShape = getTensorContent(getConstBlob(layer, value_id, 1)); - if (newShape.total() != 4 && data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC) + int inpLayout = getDataLayout(layer.input(0), data_layouts); + if (newShape.total() != 4 && inpLayout == DATA_LAYOUT_NHWC) { LayerParams permLP; int order[] = {0, 2, 3, 1}; // From OpenCV's NCHW to NHWC. @@ -969,7 +978,7 @@ void TFImporter::populateNet(Net dstNet) connect(layer_id, dstNet, inpId, permId, 0); inpId = Pin(permName); } - else if (newShape.total() == 4 && data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC) + else if (newShape.total() == 4 && inpLayout == DATA_LAYOUT_NHWC) { // NHWC->NCHW std::swap(*newShape.ptr(0, 2), *newShape.ptr(0, 3)); @@ -987,7 +996,7 @@ void TFImporter::populateNet(Net dstNet) else if (type == "Flatten" || type == "Squeeze") { Pin inpId = parsePin(layer.input(0)); - int inpLayout = data_layouts[layer.input(0)]; + int inpLayout = getDataLayout(layer.input(0), data_layouts); if (type == "Squeeze") { CV_Assert(hasLayerAttr(layer, "squeeze_dims")); @@ -1032,7 +1041,8 @@ void TFImporter::populateNet(Net dstNet) { // Only NHWC <-> NCHW permutations are allowed. OpenCV is always // keep NCHW layout this way. - if (data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC) + int inpLayout = getDataLayout(layer.input(0), data_layouts); + if (inpLayout == DATA_LAYOUT_NHWC) { if (permData[0] == 0 && permData[1] == 3 && permData[2] == 1 && permData[3] == 2) { @@ -1049,7 +1059,7 @@ void TFImporter::populateNet(Net dstNet) else CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed."); } - else if (data_layouts[layer.input(0)] == DATA_LAYOUT_NCHW) + else if (inpLayout == DATA_LAYOUT_NCHW) { if (permData[0] == 0 && permData[1] == 2 && permData[2] == 3 && permData[3] == 1) { @@ -1112,7 +1122,7 @@ void TFImporter::populateNet(Net dstNet) int axisId = (type == "Concat" ? 0 : layer.input_size() - 1); int axis = getConstBlob(layer, value_id, axisId).int_val().Get(0); - if (data_layouts[name] == DATA_LAYOUT_NHWC) + if (getDataLayout(name, data_layouts) == DATA_LAYOUT_NHWC) axis = toNCHW(axis); layerParams.set("axis", axis); @@ -1197,7 +1207,7 @@ void TFImporter::populateNet(Net dstNet) CV_Assert(!begins.empty(), !sizes.empty(), begins.type() == CV_32SC1, sizes.type() == CV_32SC1); - if (begins.total() == 4 && data_layouts[name] == DATA_LAYOUT_NHWC) + if (begins.total() == 4 && getDataLayout(name, data_layouts) == DATA_LAYOUT_NHWC) { // Swap NHWC parameters' order to NCHW. std::swap(*begins.ptr(0, 2), *begins.ptr(0, 3)); @@ -1597,7 +1607,7 @@ void TFImporter::populateNet(Net dstNet) CV_Assert(reductionIndices.type() == CV_32SC1); const int numAxes = reductionIndices.total(); - if (data_layouts[name] == DATA_LAYOUT_NHWC) + if (getDataLayout(name, data_layouts) == DATA_LAYOUT_NHWC) for (int i = 0; i < numAxes; ++i) reductionIndices.at(i) = toNCHW(reductionIndices.at(i)); diff --git a/modules/dnn/src/torch/torch_importer.cpp b/modules/dnn/src/torch/torch_importer.cpp index 3607e6c08e..88779e9977 100644 --- a/modules/dnn/src/torch/torch_importer.cpp +++ b/modules/dnn/src/torch/torch_importer.cpp @@ -592,8 +592,8 @@ struct TorchImporter DictValue dimParam = scalarParams.get("size"); layerParams.set("dim", dimParam); - if (scalarParams.has("batchMode") && scalarParams.get("batchMode")) - layerParams.set("axis", 1); + int axis = (int)scalarParams.get("batchMode", true); + layerParams.set("axis", axis); curModule->modules.push_back(newModule); } diff --git a/modules/dnn/test/test_layers.cpp b/modules/dnn/test/test_layers.cpp index 720447afb9..963206bd73 100644 --- a/modules/dnn/test/test_layers.cpp +++ b/modules/dnn/test/test_layers.cpp @@ -201,6 +201,13 @@ TEST(Layer_Test_Reshape, Accuracy) testReshape(MatShape(inp, inp + 4), MatShape(out, out + 2), 0, -1, MatShape(mask, mask + 2)); } + { + int inp[] = {1, 2, 3}; + int out[] = {3, 1, 2}; + int mask[] = {3, 1, 2}; + testReshape(MatShape(inp, inp + 3), MatShape(out, out + 3), 0, -1, + MatShape(mask, mask + 3)); + } } TEST(Layer_Test_BatchNorm, Accuracy) diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index d4ffc94399..408782233c 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -198,6 +198,7 @@ TEST_P(Test_TensorFlow_layers, reshape) { int targetId = GetParam(); runTensorFlowNet("shift_reshape_no_reorder", targetId); + runTensorFlowNet("reshape_no_reorder", targetId); runTensorFlowNet("reshape_reduce", targetId); runTensorFlowNet("flatten", targetId, true); runTensorFlowNet("unfused_flatten", targetId);