diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index d1b5a85d6c..986225a8c6 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -376,6 +376,8 @@ private: // and may be used to build the network using binary format only as a weights storage. // This approach is similar to Caffe's `.prorotxt` and `.caffemodel`. tensorflow::GraphDef netTxt; + + std::vector netInputsNames; }; TFImporter::TFImporter(const char *model, const char *config) @@ -443,7 +445,14 @@ void TFImporter::connect(const std::map& layers_name_id_map, Net& n std::map::const_iterator it = layers_name_id_map.find(outPin.name); if (it == layers_name_id_map.end()) CV_Error(Error::StsError, "Input layer not found: " + outPin.name); - network.connect(it->second, outPin.blobIndex, input_layer_id, input_blob_id); + + std::vector::iterator inpNameIt = std::find(netInputsNames.begin(), netInputsNames.end(), outPin.name); + int blobIndex; + if (inpNameIt == netInputsNames.end()) + blobIndex = outPin.blobIndex; + else + blobIndex = inpNameIt - netInputsNames.begin(); + network.connect(it->second, blobIndex, input_layer_id, input_blob_id); } void TFImporter::connectToAllBlobs(const std::map& layer_id, Net& network, const Pin& outPin, @@ -845,7 +854,7 @@ void TFImporter::populateNet(Net dstNet) Pin inp = parsePin(layer.input(ii)); if (layer_id.find(inp.name) == layer_id.end()) CV_Error(Error::StsError, "Input layer not found: " + inp.name); - dstNet.connect(layer_id.at(inp.name), inp.blobIndex, id, ii); + connect(layer_id, dstNet, inp, id, ii); } } } @@ -1099,7 +1108,7 @@ void TFImporter::populateNet(Net dstNet) Pin inp = parsePin(layer.input(ii)); if (layer_id.find(inp.name) == layer_id.end()) CV_Error(Error::StsError, "Input layer not found: " + inp.name); - dstNet.connect(layer_id.at(inp.name), inp.blobIndex, id, ii - from); + connect(layer_id, dstNet, inp, id, ii - from); } } else if (type == "MaxPool") @@ -1131,10 +1140,12 @@ void TFImporter::populateNet(Net dstNet) } else if (type == "Placeholder") { - std::vector netInputs(1); - netInputs[0] = name; - layer_id[name] = 0; - dstNet.setInputsNames(netInputs); + if (!hasLayerAttr(layer, "dtype") || + getLayerAttr(layer, "dtype").type() != tensorflow::DT_BOOL) // If input is not a train/test flag. + { + netInputsNames.push_back(name); + layer_id[name] = 0; + } } else if (type == "Split") { // TODO: determining axis index remapping by input dimensions order of input blob @@ -1272,7 +1283,7 @@ void TFImporter::populateNet(Net dstNet) Pin inp = parsePin(layer.input(ii)); if (layer_id.find(inp.name) == layer_id.end()) CV_Error(Error::StsError, "Input layer not found: " + inp.name); - dstNet.connect(layer_id.at(inp.name), inp.blobIndex, id, ii); + connect(layer_id, dstNet, inp, id, ii); } } } @@ -1790,6 +1801,7 @@ void TFImporter::populateNet(Net dstNet) } } } + dstNet.setInputsNames(netInputsNames); } } // namespace diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index 33238c718e..747fefd913 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -441,4 +441,20 @@ TEST(Test_TensorFlow, resize_bilinear) runTensorFlowNet("resize_bilinear_factor"); } +TEST(Test_TensorFlow, two_inputs) +{ + Net net = readNet(path("two_inputs_net.pbtxt")); + net.setPreferableBackend(DNN_BACKEND_OPENCV); + + Mat firstInput(2, 3, CV_32FC1), secondInput(2, 3, CV_32FC1); + randu(firstInput, -1, 1); + randu(secondInput, -1, 1); + + net.setInput(firstInput, "first_input"); + net.setInput(secondInput, "second_input"); + Mat out = net.forward(); + + normAssert(out, firstInput + secondInput); +} + }