From 7ac7aca33b082f5cecc40479bb704cddff540df6 Mon Sep 17 00:00:00 2001 From: Alexander Alekhin Date: Wed, 19 Feb 2020 07:08:01 +0000 Subject: [PATCH] dnn(caffe): fix net.input_dim handling in Caffe importer --- modules/dnn/src/caffe/caffe_importer.cpp | 98 ++++++++++++++++-------- 1 file changed, 67 insertions(+), 31 deletions(-) diff --git a/modules/dnn/src/caffe/caffe_importer.cpp b/modules/dnn/src/caffe/caffe_importer.cpp index a911f959d9..7d3d15b1b8 100644 --- a/modules/dnn/src/caffe/caffe_importer.cpp +++ b/modules/dnn/src/caffe/caffe_importer.cpp @@ -75,6 +75,17 @@ static cv::String toString(const T &v) return ss.str(); } +static inline +MatShape parseBlobShape(const caffe::BlobShape& _input_shape) +{ + MatShape shape; + for (int i = 0; i < _input_shape.dim_size(); i++) + { + shape.push_back((int)_input_shape.dim(i)); + } + return shape; +} + class CaffeImporter { caffe::NetParameter net; @@ -235,10 +246,7 @@ public: } else if (pbBlob.has_shape()) { - const caffe::BlobShape &_shape = pbBlob.shape(); - - for (int i = 0; i < _shape.dim_size(); i++) - shape.push_back((int)_shape.dim(i)); + shape = parseBlobShape(pbBlob.shape()); } else shape.resize(1, 1); // Is a scalar. @@ -334,12 +342,49 @@ public: //setup input layer names std::vector netInputs(net.input_size()); + std::vector inp_shapes; { - for (int inNum = 0; inNum < net.input_size(); inNum++) + int net_input_size = net.input_size(); + for (int inNum = 0; inNum < net_input_size; inNum++) { addedBlobs.push_back(BlobNote(net.input(inNum), 0, inNum)); netInputs[inNum] = net.input(inNum); } + + if (net.input_dim_size() > 0) // deprecated in Caffe proto + { + int net_input_dim_size = net.input_dim_size(); + CV_Check(net_input_dim_size, net_input_dim_size % 4 == 0, ""); + CV_CheckEQ(net_input_dim_size, net_input_size * 4, ""); + for (int inp_id = 0; inp_id < net_input_size; inp_id++) + { + int dim = inp_id * 4; + MatShape shape(4); + shape[0] = net.input_dim(dim); + shape[1] = net.input_dim(dim+1); + shape[2] = net.input_dim(dim+2); + shape[3] = net.input_dim(dim+3); + inp_shapes.push_back(shape); + } + } + else if (net.input_shape_size() > 0) // deprecated in Caffe proto + { + int net_input_shape_size = net.input_shape_size(); + CV_CheckEQ(net_input_shape_size, net_input_size, ""); + for (int inp_id = 0; inp_id < net_input_shape_size; inp_id++) + { + MatShape shape = parseBlobShape(net.input_shape(inp_id)); + inp_shapes.push_back(shape); + } + } + else + { + for (int inp_id = 0; inp_id < net_input_size; inp_id++) + { + MatShape shape; // empty + inp_shapes.push_back(shape); + } + } } for (int li = 0; li < layersSize; li++) @@ -364,6 +409,17 @@ public: addedBlobs.back().outNum = netInputs.size(); netInputs.push_back(addedBlobs.back().name); } + if (layer.has_input_param()) + { + const caffe::InputParameter &inputParameter = layer.input_param(); + int input_shape_size = inputParameter.shape_size(); + CV_CheckEQ(input_shape_size, layer.top_size(), ""); + for (int inp_id = 0; inp_id < input_shape_size; inp_id++) + { + MatShape shape = parseBlobShape(inputParameter.shape(inp_id)); + inp_shapes.push_back(shape); + } + } continue; } else if (type == "BatchNorm") @@ -424,35 +480,15 @@ public: } dstNet.setInputsNames(netInputs); - std::vector inp_shapes; - if (net.input_shape_size() > 0 || (layersSize > 0 && net.layer(0).has_input_param() && - net.layer(0).input_param().shape_size() > 0)) { - - int size = (net.input_shape_size() > 0) ? net.input_shape_size() : - net.layer(0).input_param().shape_size(); - for (int inp_id = 0; inp_id < size; inp_id++) + if (inp_shapes.size() > 0) + { + CV_CheckEQ(inp_shapes.size(), netInputs.size(), ""); + for (int inp_id = 0; inp_id < inp_shapes.size(); inp_id++) { - const caffe::BlobShape &_input_shape = (net.input_shape_size() > 0) ? - net.input_shape(inp_id) : - net.layer(0).input_param().shape(inp_id); - MatShape shape; - for (int i = 0; i < _input_shape.dim_size(); i++) { - shape.push_back((int)_input_shape.dim(i)); - } - inp_shapes.push_back(shape); + if (!inp_shapes[inp_id].empty()) + dstNet.setInput(Mat(inp_shapes[inp_id], CV_32F), netInputs[inp_id]); } } - else if (net.input_dim_size() > 0) { - MatShape shape; - for (int dim = 0; dim < net.input_dim_size(); dim++) { - shape.push_back(net.input_dim(dim)); - } - inp_shapes.push_back(shape); - } - - for (int inp_id = 0; inp_id < inp_shapes.size(); inp_id++) { - dstNet.setInput(Mat(inp_shapes[inp_id], CV_32F), netInputs[inp_id]); - } addedBlobs.clear(); }